properties) throws TableAlreadyExistsException, NoSuchNamespaceException;
+
+ /**
+ * Apply a set of {@link TableChange changes} to a table in the catalog.
+ *
+ * Implementations may reject the requested changes. If any change is rejected, none of the
+ * changes should be applied to the table.
+ *
+ * If the catalog supports views and contains a view for the identifier and not a table, this
+ * must throw {@link NoSuchTableException}.
+ *
+ * @param ident a table identifier
+ * @param changes changes to apply to the table
+ * @return updated metadata for the table
+ * @throws NoSuchTableException If the table doesn't exist or is a view
+ * @throws IllegalArgumentException If any change is rejected by the implementation.
+ */
+ Table alterTable(
+ Identifier ident,
+ TableChange... changes) throws NoSuchTableException;
+
+ /**
+ * Drop a table in the catalog.
+ *
+ * If the catalog supports views and contains a view for the identifier and not a table, this
+ * must not drop the view and must return false.
+ *
+ * @param ident a table identifier
+ * @return true if a table was deleted, false if no table exists for the identifier
+ */
+ boolean dropTable(Identifier ident);
+}
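The javadoc above fully specifies the alterTable and dropTable contracts. Below is a small usage sketch (not part of the patch) showing how a caller might exercise them; it assumes a TableCatalog and Identifier obtained elsewhere and uses only factory methods defined in TableChange later in this patch.

import org.apache.spark.sql.catalog.v2.Identifier;
import org.apache.spark.sql.catalog.v2.TableCatalog;
import org.apache.spark.sql.catalog.v2.TableChange;
import org.apache.spark.sql.types.DataTypes;

class AlterTableUsageSketch {
  static void evolve(TableCatalog catalog, Identifier ident) throws Exception {
    // Changes are all-or-nothing: if any one is rejected, none may be applied.
    catalog.alterTable(ident,
        TableChange.addColumn(new String[] {"event_ts"}, DataTypes.TimestampType),
        TableChange.setProperty("comment", "events table"));

    // dropTable signals a missing table by returning false rather than throwing.
    boolean dropped = catalog.dropTable(ident);
  }
}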
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java
new file mode 100644
index 0000000000000..9b87e676d9b2d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java
@@ -0,0 +1,366 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2;
+
+import org.apache.spark.sql.types.DataType;
+
+/**
+ * TableChange subclasses represent requested changes to a table. These are passed to
+ * {@link TableCatalog#alterTable}. For example,
+ *
+ * import TableChange._
+ * val catalog = Catalogs.load(name)
+ * catalog.asTableCatalog.alterTable(ident,
+ * addColumn("x", IntegerType),
+ * renameColumn("a", "b"),
+ * deleteColumn("c")
+ * )
+ *
+ */
+public interface TableChange {
+
+ /**
+ * Create a TableChange for setting a table property.
+ *
+ * If the property already exists, it will be replaced with the new value.
+ *
+ * @param property the property name
+ * @param value the new property value
+ * @return a TableChange for the addition
+ */
+ static TableChange setProperty(String property, String value) {
+ return new SetProperty(property, value);
+ }
+
+ /**
+ * Create a TableChange for removing a table property.
+ *
+ * If the property does not exist, the change will succeed.
+ *
+ * @param property the property name
+ * @return a TableChange for the removal
+ */
+ static TableChange removeProperty(String property) {
+ return new RemoveProperty(property);
+ }
+
+ /**
+ * Create a TableChange for adding an optional column.
+ *
+ * If the field already exists, the change will result in an {@link IllegalArgumentException}.
+ * If the new field is nested and its parent does not exist or is not a struct, the change will
+ * result in an {@link IllegalArgumentException}.
+ *
+ * @param fieldNames field names of the new column
+ * @param dataType the new column's data type
+ * @return a TableChange for the addition
+ */
+ static TableChange addColumn(String[] fieldNames, DataType dataType) {
+ return new AddColumn(fieldNames, dataType, true, null);
+ }
+
+ /**
+ * Create a TableChange for adding a column.
+ *
+ * If the field already exists, the change will result in an {@link IllegalArgumentException}.
+ * If the new field is nested and its parent does not exist or is not a struct, the change will
+ * result in an {@link IllegalArgumentException}.
+ *
+ * @param fieldNames field names of the new column
+ * @param dataType the new column's data type
+ * @param isNullable whether the new column can contain null
+ * @return a TableChange for the addition
+ */
+ static TableChange addColumn(String[] fieldNames, DataType dataType, boolean isNullable) {
+ return new AddColumn(fieldNames, dataType, isNullable, null);
+ }
+
+ /**
+ * Create a TableChange for adding a column.
+ *
+ * If the field already exists, the change will result in an {@link IllegalArgumentException}.
+ * If the new field is nested and its parent does not exist or is not a struct, the change will
+ * result in an {@link IllegalArgumentException}.
+ *
+ * @param fieldNames field names of the new column
+ * @param dataType the new column's data type
+ * @param isNullable whether the new column can contain null
+ * @param comment the new field's comment string
+ * @return a TableChange for the addition
+ */
+ static TableChange addColumn(
+ String[] fieldNames,
+ DataType dataType,
+ boolean isNullable,
+ String comment) {
+ return new AddColumn(fieldNames, dataType, isNullable, comment);
+ }
+
+ /**
+ * Create a TableChange for renaming a field.
+ *
+ * The name is used to find the field to rename. The new name will replace the leaf field name.
+ * For example, renameColumn(["a", "b", "c"], "x") should produce column a.b.x.
+ *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}.
+ *
+ * @param fieldNames the current field names
+ * @param newName the new name
+ * @return a TableChange for the rename
+ */
+ static TableChange renameColumn(String[] fieldNames, String newName) {
+ return new RenameColumn(fieldNames, newName);
+ }
+
+ /**
+ * Create a TableChange for updating the type of a field that is nullable.
+ *
+ * The field names are used to find the field to update.
+ *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}.
+ *
+ * @param fieldNames field names of the column to update
+ * @param newDataType the new data type
+ * @return a TableChange for the update
+ */
+ static TableChange updateColumnType(String[] fieldNames, DataType newDataType) {
+ return new UpdateColumnType(fieldNames, newDataType, true);
+ }
+
+ /**
+ * Create a TableChange for updating the type of a field.
+ *
+ * The field names are used to find the field to update.
+ *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}.
+ *
+ * @param fieldNames field names of the column to update
+ * @param newDataType the new data type
+ * @param isNullable whether the column (or nested field) can contain null
+ * @return a TableChange for the update
+ */
+ static TableChange updateColumnType(
+ String[] fieldNames,
+ DataType newDataType,
+ boolean isNullable) {
+ return new UpdateColumnType(fieldNames, newDataType, isNullable);
+ }
+
+ /**
+ * Create a TableChange for updating the comment of a field.
+ *
+ * The name is used to find the field to update.
+ *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}.
+ *
+ * @param fieldNames field names of the column to update
+ * @param newComment the new comment
+ * @return a TableChange for the update
+ */
+ static TableChange updateColumnComment(String[] fieldNames, String newComment) {
+ return new UpdateColumnComment(fieldNames, newComment);
+ }
+
+ /**
+ * Create a TableChange for deleting a field.
+ *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}.
+ *
+ * @param fieldNames field names of the column to delete
+ * @return a TableChange for the delete
+ */
+ static TableChange deleteColumn(String[] fieldNames) {
+ return new DeleteColumn(fieldNames);
+ }
+
+ /**
+ * A TableChange to set a table property.
+ *
+ * If the property already exists, it must be replaced with the new value.
+ */
+ final class SetProperty implements TableChange {
+ private final String property;
+ private final String value;
+
+ private SetProperty(String property, String value) {
+ this.property = property;
+ this.value = value;
+ }
+
+ public String property() {
+ return property;
+ }
+
+ public String value() {
+ return value;
+ }
+ }
+
+ /**
+ * A TableChange to remove a table property.
+ *
+ * If the property does not exist, the change should succeed.
+ */
+ final class RemoveProperty implements TableChange {
+ private final String property;
+
+ private RemoveProperty(String property) {
+ this.property = property;
+ }
+
+ public String property() {
+ return property;
+ }
+ }
+
+ /**
+ * A TableChange to add a field.
+ *
+ * If the field already exists, the change must result in an {@link IllegalArgumentException}.
+ * If the new field is nested and its parent does not exist or is not a struct, the change must
+ * result in an {@link IllegalArgumentException}.
+ */
+ final class AddColumn implements TableChange {
+ private final String[] fieldNames;
+ private final DataType dataType;
+ private final boolean isNullable;
+ private final String comment;
+
+ private AddColumn(String[] fieldNames, DataType dataType, boolean isNullable, String comment) {
+ this.fieldNames = fieldNames;
+ this.dataType = dataType;
+ this.isNullable = isNullable;
+ this.comment = comment;
+ }
+
+ public String[] fieldNames() {
+ return fieldNames;
+ }
+
+ public DataType dataType() {
+ return dataType;
+ }
+
+ public boolean isNullable() {
+ return isNullable;
+ }
+
+ public String comment() {
+ return comment;
+ }
+ }
+
+ /**
+ * A TableChange to rename a field.
+ *
+ * The name is used to find the field to rename. The new name will replace the leaf field name.
+ * For example, renameColumn("a.b.c", "x") should produce column a.b.x.
+ *
+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}.
+ */
+ final class RenameColumn implements TableChange {
+ private final String[] fieldNames;
+ private final String newName;
+
+ private RenameColumn(String[] fieldNames, String newName) {
+ this.fieldNames = fieldNames;
+ this.newName = newName;
+ }
+
+ public String[] fieldNames() {
+ return fieldNames;
+ }
+
+ public String newName() {
+ return newName;
+ }
+ }
+
+ /**
+ * A TableChange to update the type of a field.
+ *
+ * The field names are used to find the field to update.
+ *
+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}.
+ */
+ final class UpdateColumnType implements TableChange {
+ private final String[] fieldNames;
+ private final DataType newDataType;
+ private final boolean isNullable;
+
+ private UpdateColumnType(String[] fieldNames, DataType newDataType, boolean isNullable) {
+ this.fieldNames = fieldNames;
+ this.newDataType = newDataType;
+ this.isNullable = isNullable;
+ }
+
+ public String[] fieldNames() {
+ return fieldNames;
+ }
+
+ public DataType newDataType() {
+ return newDataType;
+ }
+
+ public boolean isNullable() {
+ return isNullable;
+ }
+ }
+
+ /**
+ * A TableChange to update the comment of a field.
+ *
+ * The field names are used to find the field to update.
+ *
+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}.
+ */
+ final class UpdateColumnComment implements TableChange {
+ private final String[] fieldNames;
+ private final String newComment;
+
+ private UpdateColumnComment(String[] fieldNames, String newComment) {
+ this.fieldNames = fieldNames;
+ this.newComment = newComment;
+ }
+
+ public String[] fieldNames() {
+ return fieldNames;
+ }
+
+ public String newComment() {
+ return newComment;
+ }
+ }
+
+ /**
+ * A TableChange to delete a field.
+ *
+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}.
+ */
+ final class DeleteColumn implements TableChange {
+ private final String[] fieldNames;
+
+ private DeleteColumn(String[] fieldNames) {
+ this.fieldNames = fieldNames;
+ }
+
+ public String[] fieldNames() {
+ return fieldNames;
+ }
+ }
+
+}
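As a complement to the factory methods above, here is a sketch (not part of the patch) of how a catalog implementation might dispatch on the concrete change classes when servicing alterTable; the addField helper and the property map are invented for illustration.

import java.util.Map;

import org.apache.spark.sql.catalog.v2.TableChange;
import org.apache.spark.sql.types.DataType;

abstract class TableChangeDispatchSketch {
  void apply(TableChange change, Map<String, String> tableProperties) {
    if (change instanceof TableChange.SetProperty) {
      TableChange.SetProperty set = (TableChange.SetProperty) change;
      tableProperties.put(set.property(), set.value());
    } else if (change instanceof TableChange.RemoveProperty) {
      tableProperties.remove(((TableChange.RemoveProperty) change).property());
    } else if (change instanceof TableChange.AddColumn) {
      TableChange.AddColumn add = (TableChange.AddColumn) change;
      addField(add.fieldNames(), add.dataType(), add.isNullable(), add.comment());
    } else {
      // Per the alterTable contract, reject unknown changes with IllegalArgumentException.
      throw new IllegalArgumentException("Unsupported table change: " + change);
    }
  }

  // Hypothetical schema-mutation helper; a real catalog would update its stored schema here.
  abstract void addField(String[] fieldNames, DataType type, boolean isNullable, String comment);
}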
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java
similarity index 73%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java
index 43bdcca70cb09..1e2aca9556df4 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java
@@ -15,12 +15,17 @@
* limitations under the License.
*/
-package org.apache.spark.sql.sources.v2;
+package org.apache.spark.sql.catalog.v2.expressions;
-import org.apache.spark.annotation.Evolving;
+import org.apache.spark.annotation.Experimental;
/**
- * TODO: remove it when we finish the API refactor for streaming write side.
+ * Base class of the public logical expression API.
*/
-@Evolving
-public interface DataSourceV2 {}
+@Experimental
+public interface Expression {
+ /**
+ * Format the expression as a human readable SQL-like string.
+ */
+ String describe();
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java
new file mode 100644
index 0000000000000..d8e49beb0bca5
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2.expressions;
+
+import java.util.Arrays;
+
+import scala.collection.JavaConverters;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.types.DataType;
+
+/**
+ * Helper methods to create logical transforms to pass into Spark.
+ */
+@Experimental
+public class Expressions {
+ private Expressions() {
+ }
+
+ /**
+ * Create a logical transform for applying a named transform.
+ *
+ * This transform can represent applying any named transform.
+ *
+ * @param name the transform name
+ * @param args expression arguments to the transform
+ * @return a logical transform
+ */
+ public static Transform apply(String name, Expression... args) {
+ return LogicalExpressions.apply(name,
+ JavaConverters.asScalaBufferConverter(Arrays.asList(args)).asScala());
+ }
+
+ /**
+ * Create a named reference expression for a column.
+ *
+ * @param name a column name
+ * @return a named reference for the column
+ */
+ public static NamedReference column(String name) {
+ return LogicalExpressions.reference(name);
+ }
+
+ /**
+ * Create a literal from a value.
+ *
+ * The JVM type of the value held by a literal must be the type used by Spark's InternalRow API
+ * for the literal's {@link DataType SQL data type}.
+ *
+ * @param value a value
+ * @param <T> the JVM type of the value
+ * @return a literal expression for the value
+ */
+ public static <T> Literal<T> literal(T value) {
+ return LogicalExpressions.literal(value);
+ }
+
+ /**
+ * Create a bucket transform for one or more columns.
+ *
+ * This transform represents a logical mapping from a value to a bucket id in [0, numBuckets)
+ * based on a hash of the value.
+ *
+ * The name reported by transforms created with this method is "bucket".
+ *
+ * @param numBuckets the number of output buckets
+ * @param columns input columns for the bucket transform
+ * @return a logical bucket transform with name "bucket"
+ */
+ public static Transform bucket(int numBuckets, String... columns) {
+ return LogicalExpressions.bucket(numBuckets,
+ JavaConverters.asScalaBufferConverter(Arrays.asList(columns)).asScala());
+ }
+
+ /**
+ * Create an identity transform for a column.
+ *
+ * This transform represents a logical mapping from a value to itself.
+ *
+ * The name reported by transforms created with this method is "identity".
+ *
+ * @param column an input column
+ * @return a logical identity transform with name "identity"
+ */
+ public static Transform identity(String column) {
+ return LogicalExpressions.identity(column);
+ }
+
+ /**
+ * Create a yearly transform for a timestamp or date column.
+ *
+ * This transform represents a logical mapping from a timestamp or date to a year, such as 2018.
+ *
+ * The name reported by transforms created with this method is "years".
+ *
+ * @param column an input timestamp or date column
+ * @return a logical yearly transform with name "years"
+ */
+ public static Transform years(String column) {
+ return LogicalExpressions.years(column);
+ }
+
+ /**
+ * Create a monthly transform for a timestamp or date column.
+ *
+ * This transform represents a logical mapping from a timestamp or date to a month, such as
+ * 2018-05.
+ *
+ * The name reported by transforms created with this method is "months".
+ *
+ * @param column an input timestamp or date column
+ * @return a logical monthly transform with name "months"
+ */
+ public static Transform months(String column) {
+ return LogicalExpressions.months(column);
+ }
+
+ /**
+ * Create a daily transform for a timestamp or date column.
+ *
+ * This transform represents a logical mapping from a timestamp or date to a date, such as
+ * 2018-05-13.
+ *
+ * The name reported by transforms created with this method is "days".
+ *
+ * @param column an input timestamp or date column
+ * @return a logical daily transform with name "days"
+ */
+ public static Transform days(String column) {
+ return LogicalExpressions.days(column);
+ }
+
+ /**
+ * Create an hourly transform for a timestamp column.
+ *
+ * This transform represents a logical mapping from a timestamp to a date and hour, such as
+ * 2018-05-13, hour 19.
+ *
+ * The name reported by transforms created with this method is "hours".
+ *
+ * @param column an input timestamp column
+ * @return a logical hourly transform with name "hours"
+ */
+ public static Transform hours(String column) {
+ return LogicalExpressions.hours(column);
+ }
+
+}
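A brief usage sketch (not part of the patch): composing a table partitioning from the helpers above. The column names and the "truncate" transform name are illustrative only.

import org.apache.spark.sql.catalog.v2.expressions.Expressions;
import org.apache.spark.sql.catalog.v2.expressions.Transform;

class PartitioningSketch {
  static Transform[] partitioning() {
    return new Transform[] {
        Expressions.days("event_ts"),        // map a timestamp column to a date
        Expressions.bucket(16, "user_id"),   // hash "user_id" into 16 buckets
        Expressions.apply("truncate",        // any named transform can also be expressed directly
            Expressions.column("name"), Expressions.literal(10))
    };
  }
}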
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java
similarity index 56%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java
index b7fa3f24a238c..e41bcf9000c52 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java
@@ -15,20 +15,28 @@
* limitations under the License.
*/
-package org.apache.spark.sql.sources.v2;
+package org.apache.spark.sql.catalog.v2.expressions;
-import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.v2.reader.Scan;
-import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.types.DataType;
/**
- * An empty mix-in interface for {@link Table}, to indicate this table supports streaming scan with
- * continuous mode.
+ * Represents a constant literal value in the public expression API.
*
- * If a {@link Table} implements this interface, the
- * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that
- * builds {@link Scan} with {@link Scan#toContinuousStream(String)} implemented.
- *
+ * The JVM type of the value held by a literal must be the type used by Spark's InternalRow API for
+ * the literal's {@link DataType SQL data type}.
+ *
+ * @param <T> the JVM type of a value held by the literal
*/
-@Evolving
-public interface SupportsContinuousRead extends SupportsRead { }
+@Experimental
+public interface Literal<T> extends Expression {
+ /**
+ * Returns the literal value.
+ */
+ T value();
+
+ /**
+ * Returns the SQL data type of the literal.
+ */
+ DataType dataType();
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java
new file mode 100644
index 0000000000000..c71ffbe70651f
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2.expressions;
+
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * Represents a field or column reference in the public logical expression API.
+ */
+@Experimental
+public interface NamedReference extends Expression {
+ /**
+ * Returns the referenced field name as an array of String parts.
+ *
+ * Each string in the returned array represents a field name.
+ */
+ String[] fieldNames();
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java
similarity index 54%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java
index 9408e323f9da1..c85e0c412f1ab 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java
@@ -15,20 +15,30 @@
* limitations under the License.
*/
-package org.apache.spark.sql.sources.v2;
+package org.apache.spark.sql.catalog.v2.expressions;
-import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.v2.reader.Scan;
-import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
+import org.apache.spark.annotation.Experimental;
/**
- * An empty mix-in interface for {@link Table}, to indicate this table supports streaming scan with
- * micro-batch mode.
+ * Represents a transform function in the public logical expression API.
*
- * If a {@link Table} implements this interface, the
- * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that
- * builds {@link Scan} with {@link Scan#toMicroBatchStream(String)} implemented.
- *
+ * For example, the transform date(ts) is used to derive a date value from a timestamp column. The
+ * transform name is "date" and its argument is a reference to the "ts" column.
*/
-@Evolving
-public interface SupportsMicroBatchRead extends SupportsRead { }
+@Experimental
+public interface Transform extends Expression {
+ /**
+ * Returns the transform function name.
+ */
+ String name();
+
+ /**
+ * Returns all field references in the transform arguments.
+ */
+ NamedReference[] references();
+
+ /**
+ * Returns the arguments passed to the transform function.
+ */
+ Expression[] arguments();
+}
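A short sketch (not part of the patch) of the date(ts) example above, built with the Expressions helper introduced elsewhere in this patch and inspected through the Transform API.

import org.apache.spark.sql.catalog.v2.expressions.Expressions;
import org.apache.spark.sql.catalog.v2.expressions.NamedReference;
import org.apache.spark.sql.catalog.v2.expressions.Transform;

class TransformInspectionSketch {
  static void inspect() {
    Transform dateOfTs = Expressions.apply("date", Expressions.column("ts"));
    String name = dateOfTs.name();                   // "date"
    NamedReference[] refs = dateOfTs.references();   // the reference to column "ts"
    int argCount = dateOfTs.arguments().length;      // 1
  }
}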
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
similarity index 89%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
index c00abd9b685b5..d27fbfdd14617 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java
@@ -20,12 +20,12 @@
import org.apache.spark.annotation.Evolving;
/**
- * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
+ * A mix-in interface for {@link TableProvider}. Data sources can implement this interface to
* propagate session configs with the specified key-prefix to all data source operations in this
* session.
*/
@Evolving
-public interface SessionConfigSupport extends DataSourceV2 {
+public interface SessionConfigSupport extends TableProvider {
/**
* Key prefix of the session configs to propagate, which is usually the data source name. Spark
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java
similarity index 76%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java
index 5031c71c0fd4d..826fa2f8a0720 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java
@@ -19,13 +19,14 @@
import org.apache.spark.sql.sources.v2.reader.Scan;
import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
/**
- * An internal base interface of mix-in interfaces for readable {@link Table}. This adds
- * {@link #newScanBuilder(DataSourceOptions)} that is used to create a scan for batch, micro-batch,
- * or continuous processing.
+ * A mix-in interface of {@link Table}, to indicate that it's readable. This adds
+ * {@link #newScanBuilder(CaseInsensitiveStringMap)} that is used to create a scan for batch,
+ * micro-batch, or continuous processing.
*/
-interface SupportsRead extends Table {
+public interface SupportsRead extends Table {
/**
* Returns a {@link ScanBuilder} which can be used to build a {@link Scan}. Spark will call this
@@ -34,5 +35,5 @@ interface SupportsRead extends Table {
* @param options The options for reading, which is an immutable case-insensitive
* string-to-string map.
*/
- ScanBuilder newScanBuilder(DataSourceOptions options);
+ ScanBuilder newScanBuilder(CaseInsensitiveStringMap options);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java
similarity index 77%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java
index ecdfe20730254..c52e54569dc0c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java
@@ -19,17 +19,18 @@
import org.apache.spark.sql.sources.v2.writer.BatchWrite;
import org.apache.spark.sql.sources.v2.writer.WriteBuilder;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
/**
- * An internal base interface of mix-in interfaces for writable {@link Table}. This adds
- * {@link #newWriteBuilder(DataSourceOptions)} that is used to create a write
+ * A mix-in interface of {@link Table}, to indicate that it's writable. This adds
+ * {@link #newWriteBuilder(CaseInsensitiveStringMap)} that is used to create a write
* for batch or streaming.
*/
-interface SupportsWrite extends Table {
+public interface SupportsWrite extends Table {
/**
* Returns a {@link WriteBuilder} which can be used to create {@link BatchWrite}. Spark will call
* this method to configure each data source write.
*/
- WriteBuilder newWriteBuilder(DataSourceOptions options);
+ WriteBuilder newWriteBuilder(CaseInsensitiveStringMap options);
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/Table.java
similarity index 65%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/Table.java
index 08664859b8de2..482d3c22e2306 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/Table.java
@@ -18,18 +18,24 @@
package org.apache.spark.sql.sources.v2;
import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalog.v2.expressions.Transform;
import org.apache.spark.sql.types.StructType;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+
/**
* An interface representing a logical structured data set of a data source. For example, the
* implementation can be a directory on the file system, a topic of Kafka, or a table in the
* catalog, etc.
*
- * This interface can mixin the following interfaces to support different operations:
- * <ul>
- *   <li>{@link SupportsBatchRead}: this table can be read in batch queries.</li>
- * </ul>
+ * This interface can mix in the following interfaces to support different operations, like
+ * {@code SupportsRead}.
+ *
+ * The default implementation of {@link #partitioning()} returns an empty array of partitions, and
+ * the default implementation of {@link #properties()} returns an empty map. These should be
+ * overridden by implementations that support partitioning and table properties.
*/
@Evolving
public interface Table {
@@ -45,4 +51,23 @@ public interface Table {
* empty schema can be returned here.
*/
StructType schema();
+
+ /**
+ * Returns the physical partitioning of this table.
+ */
+ default Transform[] partitioning() {
+ return new Transform[0];
+ }
+
+ /**
+ * Returns the string map of table properties.
+ */
+ default Map<String, String> properties() {
+ return Collections.emptyMap();
+ }
+
+ /**
+ * Returns the set of capabilities for this table.
+ */
+ Set<TableCapability> capabilities();
}
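A minimal sketch (not part of the patch) of a Table that reports only the BATCH_READ capability defined below; the read path (SupportsRead and its ScanBuilder) is omitted, and name() is assumed to be part of the Table contract even though it is outside this hunk.

import java.util.Collections;
import java.util.Set;

import org.apache.spark.sql.sources.v2.Table;
import org.apache.spark.sql.sources.v2.TableCapability;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

class BatchReadOnlyTableSketch implements Table {
  // name() is assumed to be declared by Table; it is outside the lines shown in this hunk.
  public String name() {
    return "events";
  }

  @Override
  public StructType schema() {
    return new StructType()
        .add("id", DataTypes.LongType)
        .add("event_ts", DataTypes.TimestampType);
  }

  @Override
  public Set<TableCapability> capabilities() {
    // partitioning() and properties() keep the empty defaults declared above.
    return Collections.singleton(TableCapability.BATCH_READ);
  }
}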
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java
new file mode 100644
index 0000000000000..c44a12b174f4c
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * Capabilities that can be provided by a {@link Table} implementation.
+ *
+ * Tables use {@link Table#capabilities()} to return a set of capabilities. Each capability signals
+ * to Spark that the table supports a feature identified by the capability. For example, returning
+ * {@code BATCH_READ} allows Spark to read from the table using a batch scan.
+ */
+@Experimental
+public enum TableCapability {
+ /**
+ * Signals that the table supports reads in batch execution mode.
+ */
+ BATCH_READ,
+
+ /**
+ * Signals that the table supports reads in micro-batch streaming execution mode.
+ */
+ MICRO_BATCH_READ,
+
+ /**
+ * Signals that the table supports reads in continuous streaming execution mode.
+ */
+ CONTINUOUS_READ,
+
+ /**
+ * Signals that the table supports append writes in batch execution mode.
+ *
+ * Tables that return this capability must support appending data and may also support additional
+ * write modes, like {@link #TRUNCATE}, {@link #OVERWRITE_BY_FILTER}, and
+ * {@link #OVERWRITE_DYNAMIC}.
+ */
+ BATCH_WRITE,
+
+ /**
+ * Signals that the table supports append writes in streaming execution mode.
+ *
+ * Tables that return this capability must support appending data and may also support additional
+ * write modes, like {@link #TRUNCATE}, {@link #OVERWRITE_BY_FILTER}, and
+ * {@link #OVERWRITE_DYNAMIC}.
+ */
+ STREAMING_WRITE,
+
+ /**
+ * Signals that the table can be truncated in a write operation.
+ *
+ * Truncating a table removes all existing rows.
+ *
+ * See {@code org.apache.spark.sql.sources.v2.writer.SupportsTruncate}.
+ */
+ TRUNCATE,
+
+ /**
+ * Signals that the table can replace existing data that matches a filter with appended data in
+ * a write operation.
+ *
+ * See {@code org.apache.spark.sql.sources.v2.writer.SupportsOverwrite}.
+ */
+ OVERWRITE_BY_FILTER,
+
+ /**
+ * Signals that the table can dynamically replace existing data partitions with appended data in
+ * a write operation.
+ *
+ * See {@code org.apache.spark.sql.sources.v2.writer.SupportsDynamicOverwrite}.
+ */
+ OVERWRITE_DYNAMIC,
+
+ /**
+ * Signals that the table accepts input of any schema in a write operation.
+ */
+ ACCEPT_ANY_SCHEMA
+}
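For illustration (not part of the patch), this is the kind of check the enum enables on the caller's side before a batch scan is requested.

import org.apache.spark.sql.sources.v2.Table;
import org.apache.spark.sql.sources.v2.TableCapability;

class CapabilityCheckSketch {
  static void requireBatchRead(Table table) {
    if (!table.capabilities().contains(TableCapability.BATCH_READ)) {
      throw new UnsupportedOperationException("Table does not support batch reads");
    }
  }
}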
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java
similarity index 79%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java
index 855d5efe0c69f..1d37ff042bd33 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java
@@ -18,19 +18,22 @@
package org.apache.spark.sql.sources.v2;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.DataSourceRegister;
import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
/**
* The base interface for v2 data sources which don't have a real catalog. Implementations must
* have a public, 0-arg constructor.
*
+ * Note that TableProvider can only apply data operations to existing tables, such as read,
+ * append, delete, and overwrite. It does not support operations that require metadata changes,
+ * such as creating or dropping tables.
+ *
* The major responsibility of this interface is to return a {@link Table} for read/write.
*
*/
@Evolving
-// TODO: do not extend `DataSourceV2`, after we finish the API refactor completely.
-public interface TableProvider extends DataSourceV2 {
+public interface TableProvider {
/**
* Return a {@link Table} instance to do read/write with user-specified options.
@@ -38,7 +41,7 @@ public interface TableProvider extends DataSourceV2 {
* @param options the user-specified options that can identify a table, e.g. file path, Kafka
* topic name, etc. It's an immutable case-insensitive string-to-string map.
*/
- Table getTable(DataSourceOptions options);
+ Table getTable(CaseInsensitiveStringMap options);
/**
* Return a {@link Table} instance to do read/write with user-specified schema and options.
@@ -51,14 +54,8 @@ public interface TableProvider extends DataSourceV2 {
* @param schema the user-specified schema.
* @throws UnsupportedOperationException
*/
- default Table getTable(DataSourceOptions options, StructType schema) {
- String name;
- if (this instanceof DataSourceRegister) {
- name = ((DataSourceRegister) this).shortName();
- } else {
- name = this.getClass().getName();
- }
+ default Table getTable(CaseInsensitiveStringMap options, StructType schema) {
throw new UnsupportedOperationException(
- name + " source does not support user-specified schema");
+ this.getClass().getSimpleName() + " source does not support user-specified schema");
}
}
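A minimal provider sketch (not part of the patch); the "path" option key and the inline Table are illustrative, and name() is again assumed to be part of the Table contract. A real source would infer its schema from the options rather than returning an empty one.

import java.util.Collections;
import java.util.Set;

import org.apache.spark.sql.sources.v2.Table;
import org.apache.spark.sql.sources.v2.TableCapability;
import org.apache.spark.sql.sources.v2.TableProvider;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

public class ExampleTableProvider implements TableProvider {
  @Override
  public Table getTable(CaseInsensitiveStringMap options) {
    // Options only identify the table (e.g. a file path); no metadata operations happen here.
    String path = options.get("path");
    return new Table() {
      public String name() {           // assumed to be declared by Table
        return "example(" + path + ")";
      }

      @Override
      public StructType schema() {
        return new StructType();       // a real source would infer the schema from `path`
      }

      @Override
      public Set<TableCapability> capabilities() {
        return Collections.emptySet();
      }
    };
  }
}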
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java
similarity index 90%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java
index 25ab06eee42e0..ac4f38287a24d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java
@@ -21,10 +21,8 @@
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousStream;
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream;
import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.sources.v2.SupportsBatchRead;
-import org.apache.spark.sql.sources.v2.SupportsContinuousRead;
-import org.apache.spark.sql.sources.v2.SupportsMicroBatchRead;
import org.apache.spark.sql.sources.v2.Table;
+import org.apache.spark.sql.sources.v2.TableCapability;
/**
* A logical representation of a data source scan. This interface is used to provide logical
@@ -33,8 +31,8 @@
* This logical representation is shared between batch scan, micro-batch streaming scan and
* continuous streaming scan. Data sources must implement the corresponding methods in this
* interface, to match what the table promises to support. For example, {@link #toBatch()} must be
- * implemented, if the {@link Table} that creates this {@link Scan} implements
- * {@link SupportsBatchRead}.
+ * implemented, if the {@link Table} that creates this {@link Scan} returns
+ * {@link TableCapability#BATCH_READ} support in its {@link Table#capabilities()}.
*
*/
@Evolving
@@ -62,7 +60,8 @@ default String description() {
/**
* Returns the physical representation of this scan for batch query. By default this method throws
* exception, data sources must overwrite this method to provide an implementation, if the
- * {@link Table} that creates this scan implements {@link SupportsBatchRead}.
+ * {@link Table} that creates this scan returns {@link TableCapability#BATCH_READ} in its
+ * {@link Table#capabilities()}.
*
* @throws UnsupportedOperationException
*/
@@ -73,8 +72,8 @@ default Batch toBatch() {
/**
* Returns the physical representation of this scan for streaming query with micro-batch mode. By
* default this method throws exception, data sources must overwrite this method to provide an
- * implementation, if the {@link Table} that creates this scan implements
- * {@link SupportsMicroBatchRead}.
+ * implementation, if the {@link Table} that creates this scan returns
+ * {@link TableCapability#MICRO_BATCH_READ} support in its {@link Table#capabilities()}.
*
* @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
* recovery. Data streams for the same logical source in the same query
@@ -89,8 +88,8 @@ default MicroBatchStream toMicroBatchStream(String checkpointLocation) {
/**
* Returns the physical representation of this scan for streaming query with continuous mode. By
* default this method throws exception, data sources must overwrite this method to provide an
- * implementation, if the {@link Table} that creates this scan implements
- * {@link SupportsContinuousRead}.
+ * implementation, if the {@link Table} that creates this scan returns
+ * {@link TableCapability#CONTINUOUS_READ} support in its {@link Table#capabilities()}.
*
* @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure
* recovery. Data streams for the same logical source in the same query
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
similarity index 92%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
index 296d3e47e732b..f10fd884daabe 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
@@ -29,6 +29,9 @@ public interface SupportsPushDownFilters extends ScanBuilder {
/**
* Pushes down filters, and returns filters that need to be evaluated after scanning.
+ *
+ * Rows should be returned from the data source if and only if all of the filters match. That is,
+ * filters must be interpreted as ANDed together.
*/
Filter[] pushFilters(Filter[] filters);
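To illustrate the ANDed-filters contract (not part of the patch), a source that can only evaluate equality predicates might split the incoming filters like this, handing everything else back to Spark for post-scan evaluation; the helper name is invented for illustration.

import java.util.ArrayList;
import java.util.List;

import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;

class FilterPushdownSketch {
  static Filter[] splitUnsupported(Filter[] filters, List<Filter> pushed) {
    List<Filter> postScan = new ArrayList<>();
    for (Filter f : filters) {
      if (f instanceof EqualTo) {
        pushed.add(f);      // evaluated by the source while scanning
      } else {
        postScan.add(f);    // returned to Spark, which still ANDs it with the pushed filters
      }
    }
    return postScan.toArray(new Filter[0]);
  }
}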
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
similarity index 80%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
index a06671383ac5f..1d34fdd1c28ab 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java
@@ -25,13 +25,9 @@
* During execution, offsets provided by the data source implementation will be logged and used as
* restart checkpoints. Each source should provide an offset implementation which the source can use
* to reconstruct a position in the stream up to which data has been seen/processed.
- *
- * Note: This class currently extends {@link org.apache.spark.sql.execution.streaming.Offset} to
- * maintain compatibility with DataSource V1 APIs. This extension will be removed once we
- * get rid of V1 completely.
*/
@Evolving
-public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset {
+public abstract class Offset {
/**
* A JSON-serialized representation of an Offset that is
* used for saving offsets to the offset log.
@@ -49,9 +45,8 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of
*/
@Override
public boolean equals(Object obj) {
- if (obj instanceof org.apache.spark.sql.execution.streaming.Offset) {
- return this.json()
- .equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json());
+ if (obj instanceof Offset) {
+ return this.json().equals(((Offset) obj).json());
} else {
return false;
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java
similarity index 93%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java
index 30f38ce37c401..2068a84fc6bb1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java
@@ -18,7 +18,6 @@
package org.apache.spark.sql.sources.v2.reader.streaming;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.execution.streaming.BaseStreamingSource;
/**
* The base interface representing a readable data stream in a Spark streaming query. It's
@@ -28,7 +27,7 @@
* {@link MicroBatchStream} and {@link ContinuousStream}.
*/
@Evolving
-public interface SparkDataStream extends BaseStreamingSource {
+public interface SparkDataStream {
/**
* Returns the initial offset for a streaming query to start reading from. Note that the
@@ -50,4 +49,9 @@ public interface SparkDataStream extends BaseStreamingSource {
* equal to `end` and will only request offsets greater than `end` in the future.
*/
void commit(Offset end);
+
+ /**
+ * Stop this source and free any resources it has allocated.
+ */
+ void stop();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java
new file mode 100644
index 0000000000000..8058964b662bd
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.writer;
+
+/**
+ * Write builder trait for tables that support dynamic partition overwrite.
+ *
+ * A write that dynamically overwrites partitions removes all existing data in each logical
+ * partition for which the write will commit new data. Any existing logical partition for which the
+ * write does not contain data will remain unchanged.
+ *
+ * This is provided to implement SQL compatible with Hive table operations but is not recommended.
+ * Instead, use the {@link SupportsOverwrite overwrite by filter API} to explicitly replace data.
+ */
+public interface SupportsDynamicOverwrite extends WriteBuilder {
+ /**
+ * Configures a write to dynamically replace partitions with data committed in the write.
+ *
+ * @return this write builder for method chaining
+ */
+ WriteBuilder overwriteDynamicPartitions();
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java
new file mode 100644
index 0000000000000..b443b3c3aeb4a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.writer;
+
+import org.apache.spark.sql.sources.AlwaysTrue$;
+import org.apache.spark.sql.sources.Filter;
+
+/**
+ * Write builder trait for tables that support overwrite by filter.
+ *
+ * Overwriting data by filter will delete any data that matches the filter and replace it with data
+ * that is committed in the write.
+ */
+public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate {
+ /**
+ * Configures a write to replace data matching the filters with data committed in the write.
+ *
+ * Rows must be deleted from the data source if and only if all of the filters match. That is,
+ * filters must be interpreted as ANDed together.
+ *
+ * @param filters filters used to match data to overwrite
+ * @return this write builder for method chaining
+ */
+ WriteBuilder overwrite(Filter[] filters);
+
+ @Override
+ default WriteBuilder truncate() {
+ return overwrite(new Filter[] { AlwaysTrue$.MODULE$ });
+ }
+}
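
A rough sketch of the overwrite-by-filter contract: the builder records the ANDed filters, and truncate() needs no extra code because the default above rewrites it as an overwrite with the always-true filter. FilterOverwriteBuilder and deleteFilters are invented names for illustration.

import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.sources.v2.writer.{SupportsOverwrite, WriteBuilder}

// Illustrative builder; the interface contributes overwrite() and the default truncate().
class FilterOverwriteBuilder extends SupportsOverwrite {
  private var deleteFilters: Array[Filter] = Array.empty

  override def overwrite(filters: Array[Filter]): WriteBuilder = {
    // Rows matching ALL of these filters are deleted before the new data is committed.
    deleteFilters = filters
    this
  }
}

// new FilterOverwriteBuilder().truncate() delegates to overwrite with the AlwaysTrue filter,
// so the whole table is replaced without any additional code in the builder.
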
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java
similarity index 67%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java
index c4295f2371877..69c2ba5e01a49 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java
@@ -17,10 +17,16 @@
package org.apache.spark.sql.sources.v2.writer;
-import org.apache.spark.sql.SaveMode;
-
-// A temporary mixin trait for `WriteBuilder` to support `SaveMode`. Will be removed before
-// Spark 3.0 when all the new write operators are finished. See SPARK-26356 for more details.
-public interface SupportsSaveMode extends WriteBuilder {
- WriteBuilder mode(SaveMode mode);
+/**
+ * Write builder trait for tables that support truncation.
+ *
+ * Truncation removes all data in a table and replaces it with data that is committed in the write.
+ */
+public interface SupportsTruncate extends WriteBuilder {
+ /**
+ * Configures a write to replace all existing data with data committed in the write.
+ *
+ * @return this write builder for method chaining
+ */
+ WriteBuilder truncate();
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java
similarity index 83%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java
index e861c72af9e68..bfe41f5e8dfb5 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java
@@ -18,8 +18,8 @@
package org.apache.spark.sql.sources.v2.writer;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.v2.SupportsBatchWrite;
import org.apache.spark.sql.sources.v2.Table;
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite;
import org.apache.spark.sql.types.StructType;
/**
@@ -57,13 +57,16 @@ default WriteBuilder withInputDataSchema(StructType schema) {
/**
* Returns a {@link BatchWrite} to write data to batch source. By default this method throws
* exception, data sources must overwrite this method to provide an implementation, if the
- * {@link Table} that creates this scan implements {@link SupportsBatchWrite}.
- *
- * Note that, the returned {@link BatchWrite} can be null if the implementation supports SaveMode,
- * to indicate that no writing is needed. We can clean it up after removing
- * {@link SupportsSaveMode}.
+ * {@link Table} that creates this write returns {@link TableCapability#BATCH_WRITE} support in
+ * its {@link Table#capabilities()}.
*/
default BatchWrite buildForBatch() {
- throw new UnsupportedOperationException("Batch scans are not supported");
+ throw new UnsupportedOperationException(getClass().getName() +
+ " does not support batch write");
+ }
+
+ default StreamingWrite buildForStreaming() {
+ throw new UnsupportedOperationException(getClass().getName() +
+ " does not support streaming write");
}
}
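
A sketch of how the planner side is expected to consume a builder under the capability model described above. The import path for Table and TableCapability is assumed from this patch series, and writeBatch is an invented helper, not Spark API.

import org.apache.spark.sql.sources.v2.{Table, TableCapability}
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, WriteBuilder}
import org.apache.spark.sql.types.StructType

// Hypothetical helper: only call buildForBatch() when the table advertised BATCH_WRITE.
def writeBatch(table: Table, builder: WriteBuilder, schema: StructType): BatchWrite = {
  require(table.capabilities().contains(TableCapability.BATCH_WRITE),
    "table does not declare BATCH_WRITE support")
  builder.withInputDataSchema(schema).buildForBatch()
}
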
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java
similarity index 94%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java
index 6334c8f643098..23e8580c404d4 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java
@@ -20,12 +20,12 @@
import java.io.Serializable;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport;
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite;
/**
* A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side
* as the input parameter of {@link BatchWrite#commit(WriterCommitMessage[])} or
- * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}.
+ * {@link StreamingWrite#commit(long, WriterCommitMessage[])}.
*
* This is an empty interface, data sources should define their own message class and use it when
* generating messages at executor side and handling the messages at driver side.
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
similarity index 96%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
index 7d3d21cb2b637..af2f03c9d4192 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java
@@ -26,7 +26,7 @@
/**
* A factory of {@link DataWriter} returned by
- * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating
+ * {@link StreamingWrite#createStreamingWriterFactory()}, which is responsible for creating
* and initializing the actual data writer at executor side.
*
* Note that, the writer factory will be serialized and sent to executors, then the data writer
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java
similarity index 73%
rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java
index 84cfbf2dda483..5617f1cdc0efc 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java
@@ -22,13 +22,26 @@
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage;
/**
- * An interface that defines how to write the data to data source for streaming processing.
+ * An interface that defines how to write data to a data source in streaming queries.
*
- * Streaming queries are divided into intervals of data called epochs, with a monotonically
- * increasing numeric ID. This writer handles commits and aborts for each successive epoch.
+ * The writing procedure is:
+ * 1. Create a writer factory with {@link #createStreamingWriterFactory()}, then serialize and
+ * send it to all the partitions of the input data (RDD).
+ * 2. For each epoch in each partition, create a data writer and write the epoch's data for that
+ * partition with it. If all the data is written successfully, call
+ * {@link DataWriter#commit()}. If an exception happens during writing, call
+ * {@link DataWriter#abort()}.
+ * 3. If the writers in all partitions of an epoch are successfully committed, call
+ * {@link #commit(long, WriterCommitMessage[])}. If some writers are aborted, or the job fails
+ * for an unknown reason, call {@link #abort(long, WriterCommitMessage[])}.
+ *
+ * While Spark will retry failed writing tasks, it won't retry failed writing jobs. Users should
+ * retry failed jobs manually in their Spark applications if needed.
+ *
+ * Please refer to the documentation of commit/abort methods for detailed specifications.
*/
@Evolving
-public interface StreamingWriteSupport {
+public interface StreamingWrite {
/**
* Creates a writer factory which will be serialized and sent to executors.
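 diff continues below; as a minimal sketch of the connector side of that epoch protocol, the snippet here implements the three methods named in the javadoc above. SketchStreamingWrite is an invented name and the factory body is omitted.

import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}

class SketchStreamingWrite extends StreamingWrite {
  // Serialized and sent to executors; returning a real factory is omitted in this sketch.
  override def createStreamingWriterFactory(): StreamingDataWriterFactory = ???

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // All partitions of this epoch committed; make the epoch's data visible atomically.
  }

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // A writer was aborted or the job failed; clean up anything staged for this epoch.
  }
}
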
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java
new file mode 100644
index 0000000000000..da41346d7ce71
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util;
+
+import org.apache.spark.annotation.Experimental;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Case-insensitive map of string keys to string values.
+ *
+ * This is used to pass options to v2 implementations to ensure consistent case insensitivity.
+ *
+ * Methods that return keys in this map, like {@link #entrySet()} and {@link #keySet()}, return
+ * keys converted to lower case. This map doesn't allow null keys.
+ */
+@Experimental
+public class CaseInsensitiveStringMap implements Map<String, String> {
+ private final Logger logger = LoggerFactory.getLogger(CaseInsensitiveStringMap.class);
+
+ private String unsupportedOperationMsg = "CaseInsensitiveStringMap is read-only.";
+
+ public static CaseInsensitiveStringMap empty() {
+ return new CaseInsensitiveStringMap(new HashMap<>(0));
+ }
+
+ private final Map<String, String> original;
+
+ private final Map<String, String> delegate;
+
+ public CaseInsensitiveStringMap(Map<String, String> originalMap) {
+ original = new HashMap<>(originalMap);
+ delegate = new HashMap<>(originalMap.size());
+ for (Map.Entry<String, String> entry : originalMap.entrySet()) {
+ String key = toLowerCase(entry.getKey());
+ if (delegate.containsKey(key)) {
+ logger.warn("Converting duplicated key " + entry.getKey() +
+ " into CaseInsensitiveStringMap.");
+ }
+ delegate.put(key, entry.getValue());
+ }
+ }
+
+ @Override
+ public int size() {
+ return delegate.size();
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return delegate.isEmpty();
+ }
+
+ private String toLowerCase(Object key) {
+ return key.toString().toLowerCase(Locale.ROOT);
+ }
+
+ @Override
+ public boolean containsKey(Object key) {
+ return delegate.containsKey(toLowerCase(key));
+ }
+
+ @Override
+ public boolean containsValue(Object value) {
+ return delegate.containsValue(value);
+ }
+
+ @Override
+ public String get(Object key) {
+ return delegate.get(toLowerCase(key));
+ }
+
+ @Override
+ public String put(String key, String value) {
+ throw new UnsupportedOperationException(unsupportedOperationMsg);
+ }
+
+ @Override
+ public String remove(Object key) {
+ throw new UnsupportedOperationException(unsupportedOperationMsg);
+ }
+
+ @Override
+ public void putAll(Map<? extends String, ? extends String> m) {
+ throw new UnsupportedOperationException(unsupportedOperationMsg);
+ }
+
+ @Override
+ public void clear() {
+ throw new UnsupportedOperationException(unsupportedOperationMsg);
+ }
+
+ @Override
+ public Set<String> keySet() {
+ return delegate.keySet();
+ }
+
+ @Override
+ public Collection<String> values() {
+ return delegate.values();
+ }
+
+ @Override
+ public Set<Map.Entry<String, String>> entrySet() {
+ return delegate.entrySet();
+ }
+
+ /**
+ * Returns the boolean value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public boolean getBoolean(String key, boolean defaultValue) {
+ String value = get(key);
+ // We can't use `Boolean.parseBoolean` here, as it returns false for invalid strings.
+ if (value == null) {
+ return defaultValue;
+ } else if (value.equalsIgnoreCase("true")) {
+ return true;
+ } else if (value.equalsIgnoreCase("false")) {
+ return false;
+ } else {
+ throw new IllegalArgumentException(value + " is not a boolean string.");
+ }
+ }
+
+ /**
+ * Returns the integer value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public int getInt(String key, int defaultValue) {
+ String value = get(key);
+ return value == null ? defaultValue : Integer.parseInt(value);
+ }
+
+ /**
+ * Returns the long value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public long getLong(String key, long defaultValue) {
+ String value = get(key);
+ return value == null ? defaultValue : Long.parseLong(value);
+ }
+
+ /**
+ * Returns the double value to which the specified key is mapped,
+ * or defaultValue if there is no mapping for the key. The key match is case-insensitive.
+ */
+ public double getDouble(String key, double defaultValue) {
+ String value = get(key);
+ return value == null ? defaultValue : Double.parseDouble(value);
+ }
+
+ /**
+ * Returns the original case-sensitive map.
+ */
+ public Map<String, String> asCaseSensitiveMap() {
+ return Collections.unmodifiableMap(original);
+ }
+}
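
A short usage sketch of the map's behavior; the option keys and values below are made up.

import java.util.{HashMap => JHashMap}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val raw = new JHashMap[String, String]()
raw.put("Path", "/tmp/data")              // original case is preserved in asCaseSensitiveMap()
raw.put("mergeSchema", "TRUE")

val options = new CaseInsensitiveStringMap(raw)
options.get("path")                       // "/tmp/data": lookups are case-insensitive
options.getBoolean("MERGESCHEMA", false)  // true
options.getInt("batchSize", 1024)         // 1024, the default, since the key is absent
// options.put("k", "v") would throw UnsupportedOperationException: the map is read-only
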
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
similarity index 99%
rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index 906e9bc26ef53..07d17ee14ce23 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -23,7 +23,7 @@
import org.apache.arrow.vector.holders.NullableVarCharHolder;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.execution.arrow.ArrowUtils;
+import org.apache.spark.sql.util.ArrowUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.UTF8String;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
new file mode 100644
index 0000000000000..9f917ea11d72a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.vectorized;
+
+import java.util.*;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this
+ * batch so that Spark can access the data row by row. Instances of this class are meant to be
+ * reused during the entire data loading process.
+ */
+@Evolving
+public final class ColumnarBatch {
+ private int numRows;
+ private final ColumnVector[] columns;
+
+ // Staging row returned from `getRow`.
+ private final ColumnarBatchRow row;
+
+ /**
+ * Called to close all the columns in this batch. It is not valid to access the data after
+ * calling this. This must be called at the end to clean up memory allocations.
+ */
+ public void close() {
+ for (ColumnVector c: columns) {
+ c.close();
+ }
+ }
+
+ /**
+ * Returns an iterator over the rows in this batch.
+ */
+ public Iterator<InternalRow> rowIterator() {
+ final int maxRows = numRows;
+ final ColumnarBatchRow row = new ColumnarBatchRow(columns);
+ return new Iterator<InternalRow>() {
+ int rowId = 0;
+
+ @Override
+ public boolean hasNext() {
+ return rowId < maxRows;
+ }
+
+ @Override
+ public InternalRow next() {
+ if (rowId >= maxRows) {
+ throw new NoSuchElementException();
+ }
+ row.rowId = rowId++;
+ return row;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+
+ /**
+ * Sets the number of rows in this batch.
+ */
+ public void setNumRows(int numRows) {
+ this.numRows = numRows;
+ }
+
+ /**
+ * Returns the number of columns that make up this batch.
+ */
+ public int numCols() { return columns.length; }
+
+ /**
+ * Returns the number of rows for read, including filtered rows.
+ */
+ public int numRows() { return numRows; }
+
+ /**
+ * Returns the column at `ordinal`.
+ */
+ public ColumnVector column(int ordinal) { return columns[ordinal]; }
+
+ /**
+ * Returns the row in this batch at `rowId`. Returned row is reused across calls.
+ */
+ public InternalRow getRow(int rowId) {
+ assert(rowId >= 0 && rowId < numRows);
+ row.rowId = rowId;
+ return row;
+ }
+
+ public ColumnarBatch(ColumnVector[] columns) {
+ this.columns = columns;
+ this.row = new ColumnarBatchRow(columns);
+ }
+}
+
+/**
+ * An internal class, which wraps an array of {@link ColumnVector} and provides a row view.
+ */
+class ColumnarBatchRow extends InternalRow {
+ public int rowId;
+ private final ColumnVector[] columns;
+
+ ColumnarBatchRow(ColumnVector[] columns) {
+ this.columns = columns;
+ }
+
+ @Override
+ public int numFields() { return columns.length; }
+
+ @Override
+ public InternalRow copy() {
+ GenericInternalRow row = new GenericInternalRow(columns.length);
+ for (int i = 0; i < numFields(); i++) {
+ if (isNullAt(i)) {
+ row.setNullAt(i);
+ } else {
+ DataType dt = columns[i].dataType();
+ if (dt instanceof BooleanType) {
+ row.setBoolean(i, getBoolean(i));
+ } else if (dt instanceof ByteType) {
+ row.setByte(i, getByte(i));
+ } else if (dt instanceof ShortType) {
+ row.setShort(i, getShort(i));
+ } else if (dt instanceof IntegerType) {
+ row.setInt(i, getInt(i));
+ } else if (dt instanceof LongType) {
+ row.setLong(i, getLong(i));
+ } else if (dt instanceof FloatType) {
+ row.setFloat(i, getFloat(i));
+ } else if (dt instanceof DoubleType) {
+ row.setDouble(i, getDouble(i));
+ } else if (dt instanceof StringType) {
+ row.update(i, getUTF8String(i).copy());
+ } else if (dt instanceof BinaryType) {
+ row.update(i, getBinary(i));
+ } else if (dt instanceof DecimalType) {
+ DecimalType t = (DecimalType)dt;
+ row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision());
+ } else if (dt instanceof DateType) {
+ row.setInt(i, getInt(i));
+ } else if (dt instanceof TimestampType) {
+ row.setLong(i, getLong(i));
+ } else {
+ throw new RuntimeException("Not implemented. " + dt);
+ }
+ }
+ }
+ return row;
+ }
+
+ @Override
+ public boolean anyNull() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); }
+
+ @Override
+ public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); }
+
+ @Override
+ public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); }
+
+ @Override
+ public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); }
+
+ @Override
+ public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); }
+
+ @Override
+ public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); }
+
+ @Override
+ public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); }
+
+ @Override
+ public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ return columns[ordinal].getDecimal(rowId, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ return columns[ordinal].getUTF8String(rowId);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ return columns[ordinal].getBinary(rowId);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ return columns[ordinal].getInterval(rowId);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ return columns[ordinal].getStruct(rowId);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ return columns[ordinal].getArray(rowId);
+ }
+
+ @Override
+ public ColumnarMap getMap(int ordinal) {
+ return columns[ordinal].getMap(rowId);
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ if (dataType instanceof BooleanType) {
+ return getBoolean(ordinal);
+ } else if (dataType instanceof ByteType) {
+ return getByte(ordinal);
+ } else if (dataType instanceof ShortType) {
+ return getShort(ordinal);
+ } else if (dataType instanceof IntegerType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof LongType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof FloatType) {
+ return getFloat(ordinal);
+ } else if (dataType instanceof DoubleType) {
+ return getDouble(ordinal);
+ } else if (dataType instanceof StringType) {
+ return getUTF8String(ordinal);
+ } else if (dataType instanceof BinaryType) {
+ return getBinary(ordinal);
+ } else if (dataType instanceof DecimalType) {
+ DecimalType t = (DecimalType) dataType;
+ return getDecimal(ordinal, t.precision(), t.scale());
+ } else if (dataType instanceof DateType) {
+ return getInt(ordinal);
+ } else if (dataType instanceof TimestampType) {
+ return getLong(ordinal);
+ } else if (dataType instanceof ArrayType) {
+ return getArray(ordinal);
+ } else if (dataType instanceof StructType) {
+ return getStruct(ordinal, ((StructType)dataType).fields().length);
+ } else if (dataType instanceof MapType) {
+ return getMap(ordinal);
+ } else {
+ throw new UnsupportedOperationException("Datatype not supported " + dataType);
+ }
+ }
+
+ @Override
+ public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); }
+
+ @Override
+ public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); }
+}
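
A usage sketch for the batch and its reused row view; the vectors parameter stands for connector-provided ColumnVector instances and is not constructed here, and firstColumnAsInts is an invented helper name.

import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

def firstColumnAsInts(vectors: Array[ColumnVector], rowCount: Int): Seq[Int] = {
  val batch = new ColumnarBatch(vectors)
  batch.setNumRows(rowCount)
  try {
    val it = batch.rowIterator()
    val values = Seq.newBuilder[Int]
    while (it.hasNext) {
      values += it.next().getInt(0)   // the row object is reused, so copy values out
    }
    values.result()
  } finally {
    batch.close()                     // releases the memory held by the column vectors
  }
}
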
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
similarity index 100%
rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala
similarity index 72%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala
index ac96c2765368f..86de1c9285b73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.streaming;
+package org.apache.spark.sql.catalog.v2
-/**
- * The shared interface between V1 and V2 streaming sinks.
- *
- * This is a temporary interface for compatibility during migration. It should not be implemented
- * directly, and will be removed in future versions.
- */
-public interface BaseStreamingSink {
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.Experimental
+
+@Experimental
+class CatalogNotFoundException(message: String, cause: Throwable)
+ extends SparkException(message, cause) {
+
+ def this(message: String) = this(message, null)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala
new file mode 100644
index 0000000000000..f512cd5e23c6b
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform}
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Conversion helpers for working with v2 [[CatalogPlugin]].
+ */
+object CatalogV2Implicits {
+ implicit class PartitionTypeHelper(partitionType: StructType) {
+ def asTransforms: Array[Transform] = partitionType.names.map(LogicalExpressions.identity)
+ }
+
+ implicit class BucketSpecHelper(spec: BucketSpec) {
+ def asTransform: BucketTransform = {
+ if (spec.sortColumnNames.nonEmpty) {
+ throw new AnalysisException(
+ s"Cannot convert bucketing with sort columns to a transform: $spec")
+ }
+
+ LogicalExpressions.bucket(spec.numBuckets, spec.bucketColumnNames: _*)
+ }
+ }
+
+ implicit class TransformHelper(transforms: Seq[Transform]) {
+ def asPartitionColumns: Seq[String] = {
+ val (idTransforms, nonIdTransforms) = transforms.partition(_.isInstanceOf[IdentityTransform])
+
+ if (nonIdTransforms.nonEmpty) {
+ throw new AnalysisException("Transforms cannot be converted to partition columns: " +
+ nonIdTransforms.map(_.describe).mkString(", "))
+ }
+
+ idTransforms.map(_.asInstanceOf[IdentityTransform]).map(_.reference).map { ref =>
+ val parts = ref.fieldNames
+ if (parts.size > 1) {
+ throw new AnalysisException(s"Cannot partition by nested column: $ref")
+ } else {
+ parts(0)
+ }
+ }
+ }
+ }
+
+ implicit class CatalogHelper(plugin: CatalogPlugin) {
+ def asTableCatalog: TableCatalog = plugin match {
+ case tableCatalog: TableCatalog =>
+ tableCatalog
+ case _ =>
+ throw new AnalysisException(s"Cannot use catalog ${plugin.name}: not a TableCatalog")
+ }
+ }
+
+ implicit class NamespaceHelper(namespace: Array[String]) {
+ def quoted: String = namespace.map(quote).mkString(".")
+ }
+
+ implicit class IdentifierHelper(ident: Identifier) {
+ def quoted: String = {
+ if (ident.namespace.nonEmpty) {
+ ident.namespace.map(quote).mkString(".") + "." + quote(ident.name)
+ } else {
+ quote(ident.name)
+ }
+ }
+ }
+
+ implicit class MultipartIdentifierHelper(namespace: Seq[String]) {
+ def quoted: String = namespace.map(quote).mkString(".")
+ }
+
+ private def quote(part: String): String = {
+ if (part.contains(".") || part.contains("`")) {
+ s"`${part.replace("`", "``")}`"
+ } else {
+ part
+ }
+ }
+}
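
A brief sketch of what these helpers produce; the identifier values are arbitrary.

import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
import org.apache.spark.sql.catalog.v2.Identifier

val ident = Identifier.of(Array("prod", "db"), "events")
ident.quoted                          // prod.db.events
Array("ns", "name.with.dots").quoted  // ns.`name.with.dots` (parts with '.' or '`' are back-quoted)
// plugin.asTableCatalog throws AnalysisException when the plugin is not a TableCatalog.
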
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala
new file mode 100644
index 0000000000000..5464a7496d23d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.TableIdentifier
+
+/**
+ * A trait to encapsulate the catalog lookup function and helpful extractors.
+ */
+@Experimental
+trait LookupCatalog {
+
+ protected def lookupCatalog(name: String): CatalogPlugin
+
+ type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier)
+
+ /**
+ * Extract catalog plugin and identifier from a multi-part identifier.
+ */
+ object CatalogObjectIdentifier {
+ def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match {
+ case Seq(name) =>
+ Some((None, Identifier.of(Array.empty, name)))
+ case Seq(catalogName, tail @ _*) =>
+ try {
+ Some((Some(lookupCatalog(catalogName)), Identifier.of(tail.init.toArray, tail.last)))
+ } catch {
+ case _: CatalogNotFoundException =>
+ Some((None, Identifier.of(parts.init.toArray, parts.last)))
+ }
+ }
+ }
+
+ /**
+ * Extract legacy table identifier from a multi-part identifier.
+ *
+ * For legacy support only. Please use [[CatalogObjectIdentifier]] instead on DSv2 code paths.
+ */
+ object AsTableIdentifier {
+ def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match {
+ case CatalogObjectIdentifier(None, ident) =>
+ ident.namespace match {
+ case Array() =>
+ Some(TableIdentifier(ident.name))
+ case Array(database) =>
+ Some(TableIdentifier(ident.name, Some(database)))
+ case _ =>
+ None
+ }
+ case _ =>
+ None
+ }
+ }
+}
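
A sketch of how the two extractors resolve multi-part names; DemoLookup, the "testcat" name, and the describe helper are invented for illustration.

import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}

class DemoLookup(testCatalog: CatalogPlugin) extends LookupCatalog {
  override protected def lookupCatalog(name: String): CatalogPlugin = name match {
    case "testcat" => testCatalog
    case _ => throw new CatalogNotFoundException(s"No catalog named $name")
  }

  def describe(parts: Seq[String]): String = parts match {
    // Seq("testcat", "db", "tbl") -> the catalog plus Identifier(Array("db"), "tbl")
    case CatalogObjectIdentifier(Some(catalog), ident) => s"v2 catalog ${catalog.name}: $ident"
    // Seq("db", "tbl") has no known catalog, so it falls back to a legacy TableIdentifier
    case AsTableIdentifier(tableIdent) => s"v1 table: $tableIdent"
    case _ => "unresolved"
  }
}
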
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala
new file mode 100644
index 0000000000000..2d4d6e7c6d5ee
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2.expressions
+
+import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+
+/**
+ * Helper methods for working with the logical expressions API.
+ *
+ * The factory methods can be used when referencing the logical expression nodes directly would be
+ * ambiguous because both logical and internal expressions are in scope.
+ */
+private[sql] object LogicalExpressions {
+ // a generic parser that is only used for parsing multi-part field names.
+ // because this is only used for field names, the SQL conf passed in does not matter.
+ private lazy val parser = new CatalystSqlParser(SQLConf.get)
+
+ def literal[T](value: T): LiteralValue[T] = {
+ val internalLit = catalyst.expressions.Literal(value)
+ literal(value, internalLit.dataType)
+ }
+
+ def literal[T](value: T, dataType: DataType): LiteralValue[T] = LiteralValue(value, dataType)
+
+ def reference(name: String): NamedReference =
+ FieldReference(parser.parseMultipartIdentifier(name))
+
+ def apply(name: String, arguments: Expression*): Transform = ApplyTransform(name, arguments)
+
+ def bucket(numBuckets: Int, columns: String*): BucketTransform =
+ BucketTransform(literal(numBuckets, IntegerType), columns.map(reference))
+
+ def identity(column: String): IdentityTransform = IdentityTransform(reference(column))
+
+ def years(column: String): YearsTransform = YearsTransform(reference(column))
+
+ def months(column: String): MonthsTransform = MonthsTransform(reference(column))
+
+ def days(column: String): DaysTransform = DaysTransform(reference(column))
+
+ def hours(column: String): HoursTransform = HoursTransform(reference(column))
+}
+
+/**
+ * Base class for simple transforms of a single column.
+ */
+private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends Transform {
+
+ def reference: NamedReference = ref
+
+ override def references: Array[NamedReference] = Array(ref)
+
+ override def arguments: Array[Expression] = Array(ref)
+
+ override def describe: String = name + "(" + reference.describe + ")"
+
+ override def toString: String = describe
+}
+
+private[sql] final case class BucketTransform(
+ numBuckets: Literal[Int],
+ columns: Seq[NamedReference]) extends Transform {
+
+ override val name: String = "bucket"
+
+ override def references: Array[NamedReference] = {
+ arguments
+ .filter(_.isInstanceOf[NamedReference])
+ .map(_.asInstanceOf[NamedReference])
+ }
+
+ override def arguments: Array[Expression] = numBuckets +: columns.toArray
+
+ override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})"
+
+ override def toString: String = describe
+}
+
+private[sql] final case class ApplyTransform(
+ name: String,
+ args: Seq[Expression]) extends Transform {
+
+ override def arguments: Array[Expression] = args.toArray
+
+ override def references: Array[NamedReference] = {
+ arguments
+ .filter(_.isInstanceOf[NamedReference])
+ .map(_.asInstanceOf[NamedReference])
+ }
+
+ override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
+
+ override def toString: String = describe
+}
+
+private[sql] final case class IdentityTransform(
+ ref: NamedReference) extends SingleColumnTransform(ref) {
+ override val name: String = "identity"
+ override def describe: String = ref.describe
+}
+
+private[sql] final case class YearsTransform(
+ ref: NamedReference) extends SingleColumnTransform(ref) {
+ override val name: String = "years"
+}
+
+private[sql] final case class MonthsTransform(
+ ref: NamedReference) extends SingleColumnTransform(ref) {
+ override val name: String = "months"
+}
+
+private[sql] final case class DaysTransform(
+ ref: NamedReference) extends SingleColumnTransform(ref) {
+ override val name: String = "days"
+}
+
+private[sql] final case class HoursTransform(
+ ref: NamedReference) extends SingleColumnTransform(ref) {
+ override val name: String = "hours"
+}
+
+private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
+ override def describe: String = {
+ if (dataType.isInstanceOf[StringType]) {
+ s"'$value'"
+ } else {
+ s"$value"
+ }
+ }
+ override def toString: String = describe
+}
+
+private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference {
+ import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.MultipartIdentifierHelper
+ override def fieldNames: Array[String] = parts.toArray
+ override def describe: String = parts.quoted
+ override def toString: String = describe
+}
+
+private[sql] object FieldReference {
+ def apply(column: String): NamedReference = {
+ LogicalExpressions.reference(column)
+ }
+}
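
Because these factories are private[sql], the snippet below assumes it lives in a package under org.apache.spark.sql (the package name and TransformDemo object are invented); it only shows what the describe output of the transforms looks like.

package org.apache.spark.sql.demo   // assumed location; needed because the API is private[sql]

import org.apache.spark.sql.catalog.v2.expressions.LogicalExpressions

object TransformDemo {
  def main(args: Array[String]): Unit = {
    val byDay   = LogicalExpressions.days("ts")             // describe: days(ts)
    val byId    = LogicalExpressions.identity("id")         // describe: id (column name only)
    val buckets = LogicalExpressions.bucket(16, "user_id")  // describe: bucket(16, user_id)

    println(Seq(byDay, byId, buckets).map(_.describe).mkString(", "))
  }
}
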
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala
index 6d587abd8fd4d..f5e9a146bf359 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+import org.apache.spark.sql.catalog.v2.Identifier
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
/**
@@ -25,13 +27,26 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
* as an [[org.apache.spark.sql.AnalysisException]] with the correct position information.
*/
class DatabaseAlreadyExistsException(db: String)
- extends AnalysisException(s"Database '$db' already exists")
+ extends NamespaceAlreadyExistsException(s"Database '$db' already exists")
-class TableAlreadyExistsException(db: String, table: String)
- extends AnalysisException(s"Table or view '$table' already exists in database '$db'")
+class NamespaceAlreadyExistsException(message: String) extends AnalysisException(message) {
+ def this(namespace: Array[String]) = {
+ this(s"Namespace '${namespace.quoted}' already exists")
+ }
+}
+
+class TableAlreadyExistsException(message: String) extends AnalysisException(message) {
+ def this(db: String, table: String) = {
+ this(s"Table or view '$table' already exists in database '$db'")
+ }
+
+ def this(tableIdent: Identifier) = {
+ this(s"Table ${tableIdent.quoted} already exists")
+ }
+}
class TempTableAlreadyExistsException(table: String)
- extends AnalysisException(s"Temporary view '$table' already exists")
+ extends TableAlreadyExistsException(s"Temporary view '$table' already exists")
class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec)
extends AnalysisException(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a84bb7653c527..e0c0ad6efb483 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
@@ -96,12 +97,15 @@ class Analyzer(
catalog: SessionCatalog,
conf: SQLConf,
maxIterations: Int)
- extends RuleExecutor[LogicalPlan] with CheckAnalysis {
+ extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog {
def this(catalog: SessionCatalog, conf: SQLConf) = {
this(catalog, conf, conf.optimizerMaxIterations)
}
+ override protected def lookupCatalog(name: String): CatalogPlugin =
+ throw new CatalogNotFoundException("No catalog lookup function")
+
def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = {
AnalysisHelper.markInAnalyzer {
val analyzed = executeAndTrack(plan, tracker)
@@ -978,6 +982,11 @@ class Analyzer(
case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) =>
a.mapExpressions(resolveExpressionTopDown(_, appendColumns))
+ case o: OverwriteByExpression if !o.outputResolved =>
+ // do not resolve expression attributes until the query attributes are resolved against the
+ // table by ResolveOutputRelation. that rule will alias the attributes to the table's names.
+ o
+
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
q.mapExpressions(resolveExpressionTopDown(_, q))
@@ -2237,7 +2246,7 @@ class Analyzer(
object ResolveOutputRelation extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case append @ AppendData(table, query, isByName)
- if table.resolved && query.resolved && !append.resolved =>
+ if table.resolved && query.resolved && !append.outputResolved =>
val projection = resolveOutputColumns(table.name, table.output, query, isByName)
if (projection != query) {
@@ -2245,6 +2254,26 @@ class Analyzer(
} else {
append
}
+
+ case overwrite @ OverwriteByExpression(table, _, query, isByName)
+ if table.resolved && query.resolved && !overwrite.outputResolved =>
+ val projection = resolveOutputColumns(table.name, table.output, query, isByName)
+
+ if (projection != query) {
+ overwrite.copy(query = projection)
+ } else {
+ overwrite
+ }
+
+ case overwrite @ OverwritePartitionsDynamic(table, query, isByName)
+ if table.resolved && query.resolved && !overwrite.outputResolved =>
+ val projection = resolveOutputColumns(table.name, table.output, query, isByName)
+
+ if (projection != query) {
+ overwrite.copy(query = projection)
+ } else {
+ overwrite
+ }
}
def resolveOutputColumns(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 18c40b370cb5f..fcb2eec609c28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -33,6 +33,8 @@ import org.apache.spark.sql.types._
*/
trait CheckAnalysis extends PredicateHelper {
+ import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+
/**
* Override to provide additional checks for correct analysis.
* These rules will be evaluated after our built-in check rules.
@@ -296,6 +298,21 @@ trait CheckAnalysis extends PredicateHelper {
}
}
+ case CreateTableAsSelect(_, _, partitioning, query, _, _, _) =>
+ val references = partitioning.flatMap(_.references).toSet
+ val badReferences = references.map(_.fieldNames).flatMap { column =>
+ query.schema.findNestedField(column) match {
+ case Some(_) =>
+ None
+ case _ =>
+ Some(s"${column.quoted} is missing or is in a map or array")
+ }
+ }
+
+ if (badReferences.nonEmpty) {
+ failAnalysis(s"Invalid partitioning: ${badReferences.mkString(", ")}")
+ }
+
case _ => // Fallbacks to the following checks
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala
index ad201f947b671..56b8d84441c95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala
@@ -21,4 +21,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
trait NamedRelation extends LogicalPlan {
def name: String
+
+ // When false, the schema of input data must match the schema of this relation, during write.
+ def skipSchemaResolution: Boolean = false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
index 8bf6f69f3b17a..7ac8ae61ed537 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+import org.apache.spark.sql.catalog.v2.Identifier
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -25,10 +27,24 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
* Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception
* as an [[org.apache.spark.sql.AnalysisException]] with the correct position information.
*/
-class NoSuchDatabaseException(val db: String) extends AnalysisException(s"Database '$db' not found")
+class NoSuchDatabaseException(
+ val db: String) extends NoSuchNamespaceException(s"Database '$db' not found")
-class NoSuchTableException(db: String, table: String)
- extends AnalysisException(s"Table or view '$table' not found in database '$db'")
+class NoSuchNamespaceException(message: String) extends AnalysisException(message) {
+ def this(namespace: Array[String]) = {
+ this(s"Namespace '${namespace.quoted}' not found")
+ }
+}
+
+class NoSuchTableException(message: String) extends AnalysisException(message) {
+ def this(db: String, table: String) = {
+ this(s"Table or view '$table' not found in database '$db'")
+ }
+
+ def this(tableIdent: Identifier) = {
+ this(s"Table ${tableIdent.quoted} not found")
+ }
+}
class NoSuchPartitionException(
db: String,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index a27c6d3c3671c..81ec2a1d9c904 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -29,14 +29,18 @@ import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalog.v2
+import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform}
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
@@ -86,6 +90,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
visitFunctionIdentifier(ctx.functionIdentifier)
}
+ override def visitSingleMultipartIdentifier(
+ ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) {
+ visitMultipartIdentifier(ctx.multipartIdentifier)
+ }
+
override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
visitSparkDataType(ctx.dataType)
}
@@ -117,6 +126,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
}
+ override def visitQueryToDesc(ctx: QueryToDescContext): LogicalPlan = withOrigin(ctx) {
+ plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)
+ }
+
/**
* Create a named logical plan.
*
@@ -953,6 +966,14 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText))
}
+ /**
+ * Create a multi-part identifier.
+ */
+ override def visitMultipartIdentifier(
+ ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) {
+ ctx.parts.asScala.map(_.getText)
+ }
+
/* ********************************************************************************************
* Expression parsing
* ******************************************************************************************** */
@@ -1851,4 +1872,301 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true)
if (STRING == null) structField else structField.withComment(string(STRING))
}
+
+ /**
+ * Create location string.
+ */
+ override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) {
+ string(ctx.STRING)
+ }
+
+ /**
+ * Create a [[BucketSpec]].
+ */
+ override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) {
+ BucketSpec(
+ ctx.INTEGER_VALUE.getText.toInt,
+ visitIdentifierList(ctx.identifierList),
+ Option(ctx.orderedIdentifierList)
+ .toSeq
+ .flatMap(_.orderedIdentifier.asScala)
+ .map { orderedIdCtx =>
+ Option(orderedIdCtx.ordering).map(_.getText).foreach { dir =>
+ if (dir.toLowerCase(Locale.ROOT) != "asc") {
+ operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx)
+ }
+ }
+
+ orderedIdCtx.identifier.getText
+ })
+ }
+
+ /**
+ * Convert a table property list into a key-value map.
+ * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]].
+ */
+ override def visitTablePropertyList(
+ ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) {
+ val properties = ctx.tableProperty.asScala.map { property =>
+ val key = visitTablePropertyKey(property.key)
+ val value = visitTablePropertyValue(property.value)
+ key -> value
+ }
+ // Check for duplicate property names.
+ checkDuplicateKeys(properties, ctx)
+ properties.toMap
+ }
+
+ /**
+ * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified.
+ */
+ def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = {
+ val props = visitTablePropertyList(ctx)
+ val badKeys = props.collect { case (key, null) => key }
+ if (badKeys.nonEmpty) {
+ operationNotAllowed(
+ s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
+ }
+ props
+ }
+
+ /**
+ * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified.
+ */
+ def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = {
+ val props = visitTablePropertyList(ctx)
+ val badKeys = props.filter { case (_, v) => v != null }.keys
+ if (badKeys.nonEmpty) {
+ operationNotAllowed(
+ s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
+ }
+ props.keys.toSeq
+ }
+
+ /**
+ * A table property key can either be a String or a collection of dot-separated elements. This
+ * function extracts the property key based on whether it is a string literal or a table property
+ * identifier.
+ */
+ override def visitTablePropertyKey(key: TablePropertyKeyContext): String = {
+ if (key.STRING != null) {
+ string(key.STRING)
+ } else {
+ key.getText
+ }
+ }
+
+ /**
+ * A table property value can be a String, Integer, Boolean or Decimal. This function extracts
+ * the property value based on whether it is a string, integer, boolean or decimal literal.
+ */
+ override def visitTablePropertyValue(value: TablePropertyValueContext): String = {
+ if (value == null) {
+ null
+ } else if (value.STRING != null) {
+ string(value.STRING)
+ } else if (value.booleanValue != null) {
+ value.getText.toLowerCase(Locale.ROOT)
+ } else {
+ value.getText
+ }
+ }
+
+ /**
+ * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal).
+ */
+ type TableHeader = (Seq[String], Boolean, Boolean, Boolean)
+
+ /**
+ * Validate a create table statement and return the [[TableHeader]].
+ */
+ override def visitCreateTableHeader(
+ ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
+ val temporary = ctx.TEMPORARY != null
+ val ifNotExists = ctx.EXISTS != null
+ if (temporary && ifNotExists) {
+ operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx)
+ }
+ val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText)
+ (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null)
+ }
+
+ /**
+ * Parse a list of transforms.
+ */
+ override def visitTransformList(ctx: TransformListContext): Seq[Transform] = withOrigin(ctx) {
+ def getFieldReference(
+ ctx: ApplyTransformContext,
+ arg: v2.expressions.Expression): FieldReference = {
+ lazy val name: String = ctx.identifier.getText
+ arg match {
+ case ref: FieldReference =>
+ ref
+ case nonRef =>
+ throw new ParseException(
+ s"Expected a column reference for transform $name: ${nonRef.describe}", ctx)
+ }
+ }
+
+ def getSingleFieldReference(
+ ctx: ApplyTransformContext,
+ arguments: Seq[v2.expressions.Expression]): FieldReference = {
+ lazy val name: String = ctx.identifier.getText
+ if (arguments.size > 1) {
+ throw new ParseException(s"Too many arguments for transform $name", ctx)
+ } else if (arguments.isEmpty) {
+ throw new ParseException(s"Not enough arguments for transform $name", ctx)
+ } else {
+ getFieldReference(ctx, arguments.head)
+ }
+ }
+
+ ctx.transforms.asScala.map {
+ case identityCtx: IdentityTransformContext =>
+ IdentityTransform(FieldReference(
+ identityCtx.qualifiedName.identifier.asScala.map(_.getText)))
+
+ case applyCtx: ApplyTransformContext =>
+ val arguments = applyCtx.argument.asScala.map(visitTransformArgument)
+
+ applyCtx.identifier.getText match {
+ case "bucket" =>
+ val numBuckets: Int = arguments.head match {
+ case LiteralValue(shortValue, ShortType) =>
+ shortValue.asInstanceOf[Short].toInt
+ case LiteralValue(intValue, IntegerType) =>
+ intValue.asInstanceOf[Int]
+ case LiteralValue(longValue, LongType) =>
+ longValue.asInstanceOf[Long].toInt
+ case lit =>
+ throw new ParseException(s"Invalid number of buckets: ${lit.describe}", applyCtx)
+ }
+
+ val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg))
+
+ BucketTransform(LiteralValue(numBuckets, IntegerType), fields)
+
+ case "years" =>
+ YearsTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case "months" =>
+ MonthsTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case "days" =>
+ DaysTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case "hours" =>
+ HoursTransform(getSingleFieldReference(applyCtx, arguments))
+
+ case name =>
+ ApplyTransform(name, arguments)
+ }
+ }
+ }
+
+ /**
+ * Parse an argument to a transform. An argument may be a field reference (qualified name) or
+ * a value literal.
+ */
+ override def visitTransformArgument(ctx: TransformArgumentContext): v2.expressions.Expression = {
+ withOrigin(ctx) {
+ val reference = Option(ctx.qualifiedName)
+ .map(nameCtx => FieldReference(nameCtx.identifier.asScala.map(_.getText)))
+ val literal = Option(ctx.constant)
+ .map(typedVisit[Literal])
+ .map(lit => LiteralValue(lit.value, lit.dataType))
+ reference.orElse(literal)
+ .getOrElse(throw new ParseException(s"Invalid transform argument", ctx))
+ }
+ }
+
+ /**
+ * Create a table, returning a [[CreateTableStatement]] logical plan.
+ *
+ * Expected format:
+ * {{{
+ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name
+ * USING table_provider
+ * create_table_clauses
+ * [[AS] select_statement];
+ *
+ * create_table_clauses (order insensitive):
+ * [OPTIONS table_property_list]
+ * [PARTITIONED BY (col_name, transform(col_name), transform(constant, col_name), ...)]
+ * [CLUSTERED BY (col_name, col_name, ...)
+ * [SORTED BY (col_name [ASC|DESC], ...)]
+ * INTO num_buckets BUCKETS
+ * ]
+ * [LOCATION path]
+ * [COMMENT table_comment]
+ * [TBLPROPERTIES (property_name=property_value, ...)]
+ * }}}
+ */
+ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) {
+ val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+ if (external) {
+ operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx)
+ }
+
+ checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
+ checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx)
+ checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx)
+ checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx)
+ checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx)
+ checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
+
+ val schema = Option(ctx.colTypeList()).map(createSchema)
+ val partitioning: Seq[Transform] =
+ Option(ctx.partitioning).map(visitTransformList).getOrElse(Nil)
+ val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec)
+ val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
+
+ val provider = ctx.tableProvider.qualifiedName.getText
+ val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec)
+ val comment = Option(ctx.comment).map(string)
+
+ Option(ctx.query).map(plan) match {
+ case Some(_) if temp =>
+ operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx)
+
+ case Some(_) if schema.isDefined =>
+ operationNotAllowed(
+ "Schema may not be specified in a Create Table As Select (CTAS) statement",
+ ctx)
+
+ case Some(query) =>
+ CreateTableAsSelectStatement(
+ table, query, partitioning, bucketSpec, properties, provider, options, location, comment,
+ ifNotExists = ifNotExists)
+
+ case None if temp =>
+ // CREATE TEMPORARY TABLE ... USING ... is not supported by the catalyst parser.
+ // Use CREATE TEMPORARY VIEW ... USING ... instead.
+ operationNotAllowed("CREATE TEMPORARY TABLE ... USING ...", ctx)
+
+ case _ =>
+ CreateTableStatement(table, schema.getOrElse(new StructType), partitioning, bucketSpec,
+ properties, provider, options, location, comment, ifNotExists = ifNotExists)
+ }
+ }
+
+ /**
+ * Create a [[DropTableStatement]] command.
+ */
+ override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) {
+ DropTableStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier()),
+ ctx.EXISTS != null,
+ ctx.PURGE != null)
+ }
+
+ /**
+ * Create a [[DropViewStatement]] command.
+ */
+ override def visitDropView(ctx: DropViewContext): AnyRef = withOrigin(ctx) {
+ DropViewStatement(
+ visitMultipartIdentifier(ctx.multipartIdentifier()),
+ ctx.EXISTS != null)
+ }
}
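
As a quick reference for the transform rules above, this is a minimal sketch of the expressions they would produce for a clause such as PARTITIONED BY (ts, bucket(16, id), days(ts)); the column names and bucket count are illustrative only.

    import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, DaysTransform, FieldReference, IdentityTransform, LiteralValue}
    import org.apache.spark.sql.types.IntegerType

    // identity, bucket and days transforms as built by visitTransformList
    val partitioning = Seq(
      IdentityTransform(FieldReference(Seq("ts"))),
      BucketTransform(LiteralValue(16, IntegerType), Seq(FieldReference(Seq("id")))),
      DaysTransform(FieldReference(Seq("ts"))))
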
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
index 2128a10d0b1bc..31917ab9a5579 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
@@ -57,6 +57,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
}
}
+ /** Creates a multi-part identifier for a given SQL string */
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+ parse(sqlText) { parser =>
+ astBuilder.visitSingleMultipartIdentifier(parser.singleMultipartIdentifier())
+ }
+ }
+
/**
* Creates StructType for a given SQL string, which is a comma separated list of field
* definitions which will preserve the correct Hive metadata.
@@ -85,6 +92,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
lexer.removeErrorListeners()
lexer.addErrorListener(ParseErrorListener)
lexer.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced
+ lexer.ansi = SQLConf.get.ansiParserEnabled
val tokenStream = new CommonTokenStream(lexer)
val parser = new SqlBaseParser(tokenStream)
@@ -92,6 +100,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
parser.removeErrorListeners()
parser.addErrorListener(ParseErrorListener)
parser.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced
+ parser.ansi = SQLConf.get.ansiParserEnabled
try {
try {
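
A short sketch of what the new parseMultipartIdentifier entry point is expected to return; the identifiers are illustrative.

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

    // Unquoted dots separate parts; a backtick-quoted part keeps its dots.
    assert(CatalystSqlParser.parseMultipartIdentifier("cat.db.tbl") == Seq("cat", "db", "tbl"))
    assert(CatalystSqlParser.parseMultipartIdentifier("`a.b`.c") == Seq("a.b", "c"))
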
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
index 75240d2196222..77e357ad073da 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
@@ -52,6 +52,12 @@ trait ParserInterface {
@throws[ParseException]("Text cannot be parsed to a FunctionIdentifier")
def parseFunctionIdentifier(sqlText: String): FunctionIdentifier
+ /**
+ * Parse a string to a multi-part identifier.
+ */
+ @throws[ParseException]("Text cannot be parsed to a multi-part identifier")
+ def parseMultipartIdentifier(sqlText: String): Seq[String]
+
/**
* Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
* of field definitions which will preserve the correct Hive metadata.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 639d68f4ecd76..256d3261055e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog}
+import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
@@ -365,37 +367,132 @@ case class Join(
}
/**
- * Append data to an existing table.
+ * Base trait for DataSourceV2 write commands
*/
-case class AppendData(
- table: NamedRelation,
- query: LogicalPlan,
- isByName: Boolean) extends LogicalPlan {
+trait V2WriteCommand extends Command {
+ def table: NamedRelation
+ def query: LogicalPlan
+
override def children: Seq[LogicalPlan] = Seq(query)
- override def output: Seq[Attribute] = Seq.empty
- override lazy val resolved: Boolean = {
- table.resolved && query.resolved && query.output.size == table.output.size &&
+ override lazy val resolved: Boolean = outputResolved
+
+ def outputResolved: Boolean = {
+ // If the table doesn't require schema match, we don't need to resolve the output columns.
+ table.skipSchemaResolution || {
+ table.resolved && query.resolved && query.output.size == table.output.size &&
query.output.zip(table.output).forall {
case (inAttr, outAttr) =>
// names and types must match, nullability must be compatible
inAttr.name == outAttr.name &&
- DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) &&
- (outAttr.nullable || !inAttr.nullable)
+ DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) &&
+ (outAttr.nullable || !inAttr.nullable)
}
+ }
}
}
+/**
+ * Create a new table with a v2 catalog.
+ */
+case class CreateV2Table(
+ catalog: TableCatalog,
+ tableName: Identifier,
+ tableSchema: StructType,
+ partitioning: Seq[Transform],
+ properties: Map[String, String],
+ ignoreIfExists: Boolean) extends Command
+
+/**
+ * Create a new table from a select query with a v2 catalog.
+ */
+case class CreateTableAsSelect(
+ catalog: TableCatalog,
+ tableName: Identifier,
+ partitioning: Seq[Transform],
+ query: LogicalPlan,
+ properties: Map[String, String],
+ writeOptions: Map[String, String],
+ ignoreIfExists: Boolean) extends Command {
+
+ override def children: Seq[LogicalPlan] = Seq(query)
+
+ override lazy val resolved: Boolean = {
+ // the table schema is created from the query schema, so the only resolution needed is to check
+ // that the columns referenced by the table's partitioning exist in the query schema
+ val references = partitioning.flatMap(_.references).toSet
+ references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined)
+ }
+}
+
+/**
+ * Append data to an existing table.
+ */
+case class AppendData(
+ table: NamedRelation,
+ query: LogicalPlan,
+ isByName: Boolean) extends V2WriteCommand
+
object AppendData {
def byName(table: NamedRelation, df: LogicalPlan): AppendData = {
- new AppendData(table, df, true)
+ new AppendData(table, df, isByName = true)
}
def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = {
- new AppendData(table, query, false)
+ new AppendData(table, query, isByName = false)
+ }
+}
+
+/**
+ * Overwrite data matching a filter in an existing table.
+ */
+case class OverwriteByExpression(
+ table: NamedRelation,
+ deleteExpr: Expression,
+ query: LogicalPlan,
+ isByName: Boolean) extends V2WriteCommand {
+ override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved
+}
+
+object OverwriteByExpression {
+ def byName(
+ table: NamedRelation, df: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = {
+ OverwriteByExpression(table, deleteExpr, df, isByName = true)
+ }
+
+ def byPosition(
+ table: NamedRelation, query: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = {
+ OverwriteByExpression(table, deleteExpr, query, isByName = false)
}
}
+/**
+ * Dynamically overwrite partitions in an existing table.
+ */
+case class OverwritePartitionsDynamic(
+ table: NamedRelation,
+ query: LogicalPlan,
+ isByName: Boolean) extends V2WriteCommand
+
+object OverwritePartitionsDynamic {
+ def byName(table: NamedRelation, df: LogicalPlan): OverwritePartitionsDynamic = {
+ OverwritePartitionsDynamic(table, df, isByName = true)
+ }
+
+ def byPosition(table: NamedRelation, query: LogicalPlan): OverwritePartitionsDynamic = {
+ OverwritePartitionsDynamic(table, query, isByName = false)
+ }
+}
+
+/**
+ * Drop a table.
+ */
+case class DropTable(
+ catalog: TableCatalog,
+ ident: Identifier,
+ ifExists: Boolean) extends Command
+
+
/**
* Insert some data into a table. Note that this plan is unresolved and has to be replaced by the
* concrete implementations during analysis.
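
A minimal sketch of building the new write commands through their companion objects; target and source stand in for an already-resolved NamedRelation and a source plan, and Literal.TrueLiteral is only a placeholder delete condition.

    import org.apache.spark.sql.catalyst.analysis.NamedRelation
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}

    def buildWrites(target: NamedRelation, source: LogicalPlan): Seq[LogicalPlan] = Seq(
      AppendData.byName(target, source),                                      // match columns by name
      OverwriteByExpression.byPosition(target, source, Literal.TrueLiteral),  // delete matching rows, then append
      OverwritePartitionsDynamic.byName(target, source))                      // replace partitions present in source
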
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala
new file mode 100644
index 0000000000000..7a26e01cde830
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.sql
+
+import org.apache.spark.sql.catalog.v2.expressions.Transform
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A CREATE TABLE command, as parsed from SQL.
+ *
+ * This is a metadata-only command and is not used to write data to the created table.
+ */
+case class CreateTableStatement(
+ tableName: Seq[String],
+ tableSchema: StructType,
+ partitioning: Seq[Transform],
+ bucketSpec: Option[BucketSpec],
+ properties: Map[String, String],
+ provider: String,
+ options: Map[String, String],
+ location: Option[String],
+ comment: Option[String],
+ ifNotExists: Boolean) extends ParsedStatement {
+
+ override def output: Seq[Attribute] = Seq.empty
+
+ override def children: Seq[LogicalPlan] = Seq.empty
+}
+
+/**
+ * A CREATE TABLE AS SELECT command, as parsed from SQL.
+ */
+case class CreateTableAsSelectStatement(
+ tableName: Seq[String],
+ asSelect: LogicalPlan,
+ partitioning: Seq[Transform],
+ bucketSpec: Option[BucketSpec],
+ properties: Map[String, String],
+ provider: String,
+ options: Map[String, String],
+ location: Option[String],
+ comment: Option[String],
+ ifNotExists: Boolean) extends ParsedStatement {
+
+ override def output: Seq[Attribute] = Seq.empty
+
+ override def children: Seq[LogicalPlan] = Seq(asSelect)
+}
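
For illustration, the statement node the parser changes above would produce for a query like CREATE TABLE IF NOT EXISTS cat.db.tbl (id BIGINT, ts TIMESTAMP) USING parquet PARTITIONED BY (days(ts)) looks roughly like the following hand-built sketch (not captured parser output).

    import org.apache.spark.sql.catalog.v2.expressions.{DaysTransform, FieldReference}
    import org.apache.spark.sql.catalyst.plans.logical.sql.CreateTableStatement
    import org.apache.spark.sql.types.{LongType, StructType, TimestampType}

    val stmt = CreateTableStatement(
      tableName = Seq("cat", "db", "tbl"),
      tableSchema = new StructType().add("id", LongType).add("ts", TimestampType),
      partitioning = Seq(DaysTransform(FieldReference(Seq("ts")))),
      bucketSpec = None,
      properties = Map.empty,
      provider = "parquet",
      options = Map.empty,
      location = None,
      comment = None,
      ifNotExists = true)
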
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala
new file mode 100644
index 0000000000000..d41e8a5010257
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.sql
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * A DROP TABLE statement, as parsed from SQL.
+ */
+case class DropTableStatement(
+ tableName: Seq[String],
+ ifExists: Boolean,
+ purge: Boolean) extends ParsedStatement {
+
+ override def output: Seq[Attribute] = Seq.empty
+
+ override def children: Seq[LogicalPlan] = Seq.empty
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala
new file mode 100644
index 0000000000000..523158788e834
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.sql
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * A DROP VIEW statement, as parsed from SQL.
+ */
+case class DropViewStatement(
+ viewName: Seq[String],
+ ifExists: Boolean) extends ParsedStatement {
+
+ override def output: Seq[Attribute] = Seq.empty
+
+ override def children: Seq[LogicalPlan] = Seq.empty
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala
new file mode 100644
index 0000000000000..510f2a1ba1e6d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical.sql
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * A logical plan node that contains exactly what was parsed from SQL.
+ *
+ * This is used to hold information parsed from SQL when there are multiple implementations of a
+ * query or command. For example, CREATE TABLE may be implemented by different nodes for v1 and v2.
+ * Instead of parsing directly to a v1 CreateTable that keeps metadata in CatalogTable, and then
+ * converting that v1 metadata to the v2 equivalent, the sql [[CreateTableStatement]] plan is
+ * produced by the parser and converted once into both implementations.
+ *
+ * Parsed logical plans are not resolved because they must be converted to concrete logical plans.
+ *
+ * Parsed logical plans are located in Catalyst so that as much SQL parsing logic as possible can be
+ * kept in a [[org.apache.spark.sql.catalyst.parser.AbstractSqlParser]].
+ */
+private[sql] abstract class ParsedStatement extends LogicalPlan {
+ // Redact properties and options when parsed nodes are used by generic methods like toString
+ override def productIterator: Iterator[Any] = super.productIterator.map {
+ case mapArg: Map[_, _] => conf.redactOptions(mapArg.asInstanceOf[Map[String, String]])
+ case other => other
+ }
+
+ final override lazy val resolved = false
+}
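
Because resolved is always false, an analyzer rule has to rewrite these nodes before analysis can succeed. A minimal sketch of such a conversion rule follows; the rule name and the convert* helpers are hypothetical stand-ins for the real conversions elsewhere in this change.

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableStatement, DropTableStatement}
    import org.apache.spark.sql.catalyst.rules.Rule

    object ResolveParsedStatements extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
        case c: CreateTableStatement => convertCreateTable(c) // e.g. to CreateV2Table or a v1 plan
        case d: DropTableStatement   => convertDropTable(d)   // e.g. to DropTable
      }

      // Hypothetical helpers; the real conversions live in the analyzer/planner.
      private def convertCreateTable(c: CreateTableStatement): LogicalPlan = ???
      private def convertDropTable(d: DropTableStatement): LogicalPlan = ???
    }
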
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index f590c63f80b21..a85cad35ac6fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util
import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
+import java.time.ZoneId
import java.util.{Calendar, Locale, TimeZone}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.{Function => JFunction}
@@ -123,6 +124,8 @@ object DateTimeUtils {
override def apply(timeZoneId: String): TimeZone = TimeZone.getTimeZone(timeZoneId)
}
+ def getZoneId(timeZoneId: String): ZoneId = ZoneId.of(timeZoneId, ZoneId.SHORT_IDS)
+
def getTimeZone(timeZoneId: String): TimeZone = {
computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone)
}
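
A small sketch of the new getZoneId helper: short IDs resolve through ZoneId.SHORT_IDS, while region and offset IDs pass through unchanged.

    import java.time.ZoneId
    import org.apache.spark.sql.catalyst.util.DateTimeUtils

    assert(DateTimeUtils.getZoneId("PST") == ZoneId.of("America/Los_Angeles"))
    assert(DateTimeUtils.getZoneId("UTC").getId == "UTC")
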
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 448338f61346f..cbc57066163c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -314,6 +314,12 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val ANSI_SQL_PARSER =
+ buildConf("spark.sql.parser.ansi.enabled")
+ .doc("When true, tries to conform to ANSI SQL syntax.")
+ .booleanConf
+ .createWithDefault(false)
+
val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
.internal()
.doc("When true, string literals (including regex patterns) remain escaped in our SQL " +
@@ -918,6 +924,12 @@ object SQLConf {
.stringConf
.createOptional
+ val FORCE_DELETE_TEMP_CHECKPOINT_LOCATION =
+ buildConf("spark.sql.streaming.forceDeleteTempCheckpointLocation")
+ .doc("When true, enable temporary checkpoint locations force delete.")
+ .booleanConf
+ .createWithDefault(false)
+
val MIN_BATCHES_TO_RETAIN = buildConf("spark.sql.streaming.minBatchesToRetain")
.internal()
.doc("The minimum number of batches that must be retained and made recoverable.")
@@ -1117,6 +1129,14 @@ object SQLConf {
.internal()
.stringConf
+ val STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED =
+ buildConf("spark.sql.streaming.checkpoint.escapedPathCheck.enabled")
+ .doc("Whether to detect a streaming query may pick up an incorrect checkpoint path due " +
+ "to SPARK-26824.")
+ .internal()
+ .booleanConf
+ .createWithDefault(true)
+
val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION =
buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled")
.internal()
@@ -1427,6 +1447,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE =
+ buildConf("spark.sql.streaming.continuous.epochBacklogQueueSize")
+ .doc("The max number of entries to be stored in queue to wait for late epochs. " +
+ "If this parameter is exceeded by the size of the queue, stream will stop with an error.")
+ .intConf
+ .createWithDefault(10000)
+
val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.executorQueueSize")
.internal()
@@ -1457,7 +1484,7 @@ object SQLConf {
" register class names for which data source V2 write paths are disabled. Writes from these" +
" sources will fall back to the V1 sources.")
.stringConf
- .createWithDefault("")
+ .createWithDefault("orc")
val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers")
.doc("A comma-separated list of fully qualified data source register class names for which" +
@@ -1673,6 +1700,11 @@ object SQLConf {
"a SparkConf entry.")
.booleanConf
.createWithDefault(true)
+
+ val DEFAULT_V2_CATALOG = buildConf("spark.sql.default.catalog")
+ .doc("Name of the default v2 catalog, used when a catalog is not identified in queries")
+ .stringConf
+ .createOptional
}
/**
@@ -1848,6 +1880,8 @@ class SQLConf extends Serializable with Logging {
def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED)
+ def ansiParserEnabled: Boolean = getConf(ANSI_SQL_PARSER)
+
def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR)
@@ -2059,14 +2093,17 @@ class SQLConf extends Serializable with Logging {
def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION)
+ def continuousStreamingEpochBacklogQueueSize: Int =
+ getConf(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE)
+
def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE)
def continuousStreamingExecutorPollIntervalMs: Long =
getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS)
- def userV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST)
+ def useV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST)
- def userV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST)
+ def useV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST)
def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS)
@@ -2118,6 +2155,8 @@ class SQLConf extends Serializable with Logging {
def setCommandRejectsSparkCoreConfs: Boolean =
getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS)
+ def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG)
+
/** ********************** SQLConf functionality methods ************ */
/** Set Spark SQL configuration properties. */
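
A hedged sketch of wiring up the new configuration entries in a session; testcat and com.example.MyCatalogPlugin are made-up names used only for illustration.

    // `spark` is an existing SparkSession.
    spark.conf.set("spark.sql.parser.ansi.enabled", "true")
    spark.conf.set("spark.sql.default.catalog", "testcat")
    // Catalog implementations are looked up under spark.sql.catalog.<name>.
    spark.conf.set("spark.sql.catalog.testcat", "com.example.MyCatalogPlugin") // hypothetical plugin class
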
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
similarity index 92%
rename from sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
index 3f941cc6e1072..a1ab55a7185ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.sources
-import org.apache.spark.annotation.Stable
+import org.apache.spark.annotation.{Evolving, Stable}
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines all the filters that we can push down to the data sources.
@@ -218,3 +218,27 @@ case class StringEndsWith(attribute: String, value: String) extends Filter {
case class StringContains(attribute: String, value: String) extends Filter {
override def references: Array[String] = Array(attribute)
}
+
+/**
+ * A filter that always evaluates to `true`.
+ */
+@Evolving
+case class AlwaysTrue() extends Filter {
+ override def references: Array[String] = Array.empty
+}
+
+@Evolving
+object AlwaysTrue extends AlwaysTrue {
+}
+
+/**
+ * A filter that always evaluates to `false`.
+ */
+@Evolving
+case class AlwaysFalse() extends Filter {
+ override def references: Array[String] = Array.empty
+}
+
+@Evolving
+object AlwaysFalse extends AlwaysFalse {
+}
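
A brief sketch of the new constant filters; either the case class or its companion object can be used wherever a Filter is expected.

    import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, Filter}

    val keepAll: Filter = AlwaysTrue     // companion object, itself an AlwaysTrue
    val dropAll: Filter = AlwaysFalse()  // case class instance
    assert(keepAll.references.isEmpty && dropAll.references.isEmpty)
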
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index d563276a5711d..c472bd8ee84b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -307,6 +307,29 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
nameToIndex.get(name)
}
+ /**
+ * Returns the field at the given name path, searching this struct and its nested child structs.
+ *
+ * This does not support finding fields nested in maps or arrays.
+ */
+ private[sql] def findNestedField(fieldNames: Seq[String]): Option[StructField] = {
+ fieldNames.headOption.flatMap(nameToField.get) match {
+ case Some(field) =>
+ if (fieldNames.tail.isEmpty) {
+ Some(field)
+ } else {
+ field.dataType match {
+ case struct: StructType =>
+ struct.findNestedField(fieldNames.tail)
+ case _ =>
+ None
+ }
+ }
+ case _ =>
+ None
+ }
+ }
+
protected[sql] def toAttributes: Seq[AttributeReference] =
map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
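
A sketch of the new nested-field lookup; note it is private[sql], so it is only callable from within the org.apache.spark.sql packages, and map/array nesting is intentionally unsupported.

    import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

    val schema = new StructType()
      .add("id", IntegerType)
      .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))

    assert(schema.findNestedField(Seq("point", "x")).isDefined)
    assert(schema.findNestedField(Seq("point", "z")).isEmpty)
    assert(schema.findNestedField(Seq("id", "x")).isEmpty) // "id" is not a struct
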
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
similarity index 99%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 7de6256aef084..62546a322d3c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.arrow
+package org.apache.spark.sql.util
import scala.collection.JavaConverters._
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java
new file mode 100644
index 0000000000000..326b12f3618d3
--- /dev/null
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2;
+
+import org.apache.spark.SparkException;
+import org.apache.spark.sql.internal.SQLConf;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.concurrent.Callable;
+
+public class CatalogLoadingSuite {
+ @Test
+ public void testLoad() throws SparkException {
+ SQLConf conf = new SQLConf();
+ conf.setConfString("spark.sql.catalog.test-name", TestCatalogPlugin.class.getCanonicalName());
+
+ CatalogPlugin plugin = Catalogs.load("test-name", conf);
+ Assert.assertNotNull("Should instantiate a non-null plugin", plugin);
+ Assert.assertEquals("Plugin should have correct implementation",
+ TestCatalogPlugin.class, plugin.getClass());
+
+ TestCatalogPlugin testPlugin = (TestCatalogPlugin) plugin;
+ Assert.assertEquals("Options should contain no keys", 0, testPlugin.options.size());
+ Assert.assertEquals("Catalog should have correct name", "test-name", testPlugin.name());
+ }
+
+ @Test
+ public void testInitializationOptions() throws SparkException {
+ SQLConf conf = new SQLConf();
+ conf.setConfString("spark.sql.catalog.test-name", TestCatalogPlugin.class.getCanonicalName());
+ conf.setConfString("spark.sql.catalog.test-name.name", "not-catalog-name");
+ conf.setConfString("spark.sql.catalog.test-name.kEy", "valUE");
+
+ CatalogPlugin plugin = Catalogs.load("test-name", conf);
+ Assert.assertNotNull("Should instantiate a non-null plugin", plugin);
+ Assert.assertEquals("Plugin should have correct implementation",
+ TestCatalogPlugin.class, plugin.getClass());
+
+ TestCatalogPlugin testPlugin = (TestCatalogPlugin) plugin;
+
+ Assert.assertEquals("Options should contain only two keys", 2, testPlugin.options.size());
+ Assert.assertEquals("Options should contain correct value for name (not overwritten)",
+ "not-catalog-name", testPlugin.options.get("name"));
+ Assert.assertEquals("Options should contain correct value for key",
+ "valUE", testPlugin.options.get("key"));
+ }
+
+ @Test
+ public void testLoadWithoutConfig() {
+ SQLConf conf = new SQLConf();
+
+ SparkException exc = intercept(CatalogNotFoundException.class,
+ () -> Catalogs.load("missing", conf));
+
+ Assert.assertTrue("Should complain that implementation is not configured",
+ exc.getMessage()
+ .contains("plugin class not found: spark.sql.catalog.missing is not defined"));
+ Assert.assertTrue("Should identify the catalog by name",
+ exc.getMessage().contains("missing"));
+ }
+
+ @Test
+ public void testLoadMissingClass() {
+ SQLConf conf = new SQLConf();
+ conf.setConfString("spark.sql.catalog.missing", "com.example.NoSuchCatalogPlugin");
+
+ SparkException exc = intercept(SparkException.class, () -> Catalogs.load("missing", conf));
+
+ Assert.assertTrue("Should complain that the class is not found",
+ exc.getMessage().contains("Cannot find catalog plugin class"));
+ Assert.assertTrue("Should identify the catalog by name",
+ exc.getMessage().contains("missing"));
+ Assert.assertTrue("Should identify the missing class",
+ exc.getMessage().contains("com.example.NoSuchCatalogPlugin"));
+ }
+
+ @Test
+ public void testLoadNonCatalogPlugin() {
+ SQLConf conf = new SQLConf();
+ String invalidClassName = InvalidCatalogPlugin.class.getCanonicalName();
+ conf.setConfString("spark.sql.catalog.invalid", invalidClassName);
+
+ SparkException exc = intercept(SparkException.class, () -> Catalogs.load("invalid", conf));
+
+ Assert.assertTrue("Should complain that class does not implement CatalogPlugin",
+ exc.getMessage().contains("does not implement CatalogPlugin"));
+ Assert.assertTrue("Should identify the catalog by name",
+ exc.getMessage().contains("invalid"));
+ Assert.assertTrue("Should identify the class",
+ exc.getMessage().contains(invalidClassName));
+ }
+
+ @Test
+ public void testLoadConstructorFailureCatalogPlugin() {
+ SQLConf conf = new SQLConf();
+ String invalidClassName = ConstructorFailureCatalogPlugin.class.getCanonicalName();
+ conf.setConfString("spark.sql.catalog.invalid", invalidClassName);
+
+ RuntimeException exc = intercept(RuntimeException.class, () -> Catalogs.load("invalid", conf));
+
+ Assert.assertTrue("Should have expected error message",
+ exc.getMessage().contains("Expected failure"));
+ }
+
+ @Test
+ public void testLoadAccessErrorCatalogPlugin() {
+ SQLConf conf = new SQLConf();
+ String invalidClassName = AccessErrorCatalogPlugin.class.getCanonicalName();
+ conf.setConfString("spark.sql.catalog.invalid", invalidClassName);
+
+ SparkException exc = intercept(SparkException.class, () -> Catalogs.load("invalid", conf));
+
+ Assert.assertTrue("Should complain that no public constructor is provided",
+ exc.getMessage().contains("Failed to call public no-arg constructor for catalog"));
+ Assert.assertTrue("Should identify the catalog by name",
+ exc.getMessage().contains("invalid"));
+ Assert.assertTrue("Should identify the class",
+ exc.getMessage().contains(invalidClassName));
+ }
+
+ @SuppressWarnings("unchecked")
+ public static <E extends Exception> E intercept(Class<E> expected, Callable<?> callable) {
+ try {
+ callable.call();
+ Assert.fail("No exception was thrown, expected: " +
+ expected.getName());
+ } catch (Exception actual) {
+ try {
+ Assert.assertEquals(expected, actual.getClass());
+ return (E) actual;
+ } catch (AssertionError e) {
+ e.addSuppressed(actual);
+ throw e;
+ }
+ }
+ // Compiler doesn't catch that Assert.fail will always throw an exception.
+ throw new UnsupportedOperationException("[BUG] Should not reach this statement");
+ }
+}
+
+class TestCatalogPlugin implements CatalogPlugin {
+ String name = null;
+ CaseInsensitiveStringMap options = null;
+
+ TestCatalogPlugin() {
+ }
+
+ @Override
+ public void initialize(String name, CaseInsensitiveStringMap options) {
+ this.name = name;
+ this.options = options;
+ }
+
+ @Override
+ public String name() {
+ return name;
+ }
+}
+
+class ConstructorFailureCatalogPlugin implements CatalogPlugin { // fails in its constructor
+ ConstructorFailureCatalogPlugin() {
+ throw new RuntimeException("Expected failure.");
+ }
+
+ @Override
+ public void initialize(String name, CaseInsensitiveStringMap options) {
+ }
+
+ @Override
+ public String name() {
+ return null;
+ }
+}
+
+class AccessErrorCatalogPlugin implements CatalogPlugin { // no public constructor
+ private AccessErrorCatalogPlugin() {
+ }
+
+ @Override
+ public void initialize(String name, CaseInsensitiveStringMap options) {
+ }
+
+ @Override
+ public String name() {
+ return null;
+ }
+}
+
+class InvalidCatalogPlugin { // doesn't implement CatalogPlugin
+ public void initialize(CaseInsensitiveStringMap options) {
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala
new file mode 100644
index 0000000000000..9c1b9a3e53de2
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala
@@ -0,0 +1,657 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2
+
+import java.util
+import java.util.Collections
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class TableCatalogSuite extends SparkFunSuite {
+ import CatalogV2Implicits._
+
+ private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
+ private val schema: StructType = new StructType()
+ .add("id", IntegerType)
+ .add("data", StringType)
+
+ private def newCatalog(): TableCatalog = {
+ val newCatalog = new TestTableCatalog
+ newCatalog.initialize("test", CaseInsensitiveStringMap.empty())
+ newCatalog
+ }
+
+ private val testIdent = Identifier.of(Array("`", "."), "test_table")
+
+ test("Catalogs can load the catalog") {
+ val catalog = newCatalog()
+
+ val conf = new SQLConf
+ conf.setConfString("spark.sql.catalog.test", catalog.getClass.getName)
+
+ val loaded = Catalogs.load("test", conf)
+ assert(loaded.getClass == catalog.getClass)
+ }
+
+ test("listTables") {
+ val catalog = newCatalog()
+ val ident1 = Identifier.of(Array("ns"), "test_table_1")
+ val ident2 = Identifier.of(Array("ns"), "test_table_2")
+ val ident3 = Identifier.of(Array("ns2"), "test_table_1")
+
+ assert(catalog.listTables(Array("ns")).isEmpty)
+
+ catalog.createTable(ident1, schema, Array.empty, emptyProps)
+
+ assert(catalog.listTables(Array("ns")).toSet == Set(ident1))
+ assert(catalog.listTables(Array("ns2")).isEmpty)
+
+ catalog.createTable(ident3, schema, Array.empty, emptyProps)
+ catalog.createTable(ident2, schema, Array.empty, emptyProps)
+
+ assert(catalog.listTables(Array("ns")).toSet == Set(ident1, ident2))
+ assert(catalog.listTables(Array("ns2")).toSet == Set(ident3))
+
+ catalog.dropTable(ident1)
+
+ assert(catalog.listTables(Array("ns")).toSet == Set(ident2))
+
+ catalog.dropTable(ident2)
+
+ assert(catalog.listTables(Array("ns")).isEmpty)
+ assert(catalog.listTables(Array("ns2")).toSet == Set(ident3))
+ }
+
+ test("createTable") {
+ val catalog = newCatalog()
+
+ assert(!catalog.tableExists(testIdent))
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name)
+ assert(parsed == Seq("`", ".", "test_table"))
+ assert(table.schema == schema)
+ assert(table.properties.asScala == Map())
+
+ assert(catalog.tableExists(testIdent))
+ }
+
+ test("createTable: with properties") {
+ val catalog = newCatalog()
+
+ val properties = new util.HashMap[String, String]()
+ properties.put("property", "value")
+
+ assert(!catalog.tableExists(testIdent))
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, properties)
+
+ val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name)
+ assert(parsed == Seq("`", ".", "test_table"))
+ assert(table.schema == schema)
+ assert(table.properties == properties)
+
+ assert(catalog.tableExists(testIdent))
+ }
+
+ test("createTable: table already exists") {
+ val catalog = newCatalog()
+
+ assert(!catalog.tableExists(testIdent))
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ val exc = intercept[TableAlreadyExistsException] {
+ catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+ }
+
+ assert(exc.message.contains(table.name()))
+ assert(exc.message.contains("already exists"))
+
+ assert(catalog.tableExists(testIdent))
+ }
+
+ test("tableExists") {
+ val catalog = newCatalog()
+
+ assert(!catalog.tableExists(testIdent))
+
+ catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(catalog.tableExists(testIdent))
+
+ catalog.dropTable(testIdent)
+
+ assert(!catalog.tableExists(testIdent))
+ }
+
+ test("loadTable") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+ val loaded = catalog.loadTable(testIdent)
+
+ assert(table.name == loaded.name)
+ assert(table.schema == loaded.schema)
+ assert(table.properties == loaded.properties)
+ }
+
+ test("loadTable: table does not exist") {
+ val catalog = newCatalog()
+
+ val exc = intercept[NoSuchTableException] {
+ catalog.loadTable(testIdent)
+ }
+
+ assert(exc.message.contains(testIdent.quoted))
+ assert(exc.message.contains("not found"))
+ }
+
+ test("invalidateTable") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+ catalog.invalidateTable(testIdent)
+
+ val loaded = catalog.loadTable(testIdent)
+
+ assert(table.name == loaded.name)
+ assert(table.schema == loaded.schema)
+ assert(table.properties == loaded.properties)
+ }
+
+ test("invalidateTable: table does not exist") {
+ val catalog = newCatalog()
+
+ assert(catalog.tableExists(testIdent) === false)
+
+ catalog.invalidateTable(testIdent)
+ }
+
+ test("alterTable: add property") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.properties.asScala == Map())
+
+ val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-1", "1"))
+ assert(updated.properties.asScala == Map("prop-1" -> "1"))
+
+ val loaded = catalog.loadTable(testIdent)
+ assert(loaded.properties.asScala == Map("prop-1" -> "1"))
+
+ assert(table.properties.asScala == Map())
+ }
+
+ test("alterTable: add property to existing") {
+ val catalog = newCatalog()
+
+ val properties = new util.HashMap[String, String]()
+ properties.put("prop-1", "1")
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, properties)
+
+ assert(table.properties.asScala == Map("prop-1" -> "1"))
+
+ val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-2", "2"))
+ assert(updated.properties.asScala == Map("prop-1" -> "1", "prop-2" -> "2"))
+
+ val loaded = catalog.loadTable(testIdent)
+ assert(loaded.properties.asScala == Map("prop-1" -> "1", "prop-2" -> "2"))
+
+ assert(table.properties.asScala == Map("prop-1" -> "1"))
+ }
+
+ test("alterTable: remove existing property") {
+ val catalog = newCatalog()
+
+ val properties = new util.HashMap[String, String]()
+ properties.put("prop-1", "1")
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, properties)
+
+ assert(table.properties.asScala == Map("prop-1" -> "1"))
+
+ val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1"))
+ assert(updated.properties.asScala == Map())
+
+ val loaded = catalog.loadTable(testIdent)
+ assert(loaded.properties.asScala == Map())
+
+ assert(table.properties.asScala == Map("prop-1" -> "1"))
+ }
+
+ test("alterTable: remove missing property") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.properties.asScala == Map())
+
+ val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1"))
+ assert(updated.properties.asScala == Map())
+
+ val loaded = catalog.loadTable(testIdent)
+ assert(loaded.properties.asScala == Map())
+
+ assert(table.properties.asScala == Map())
+ }
+
+ test("alterTable: add top-level column") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val updated = catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType))
+
+ assert(updated.schema == schema.add("ts", TimestampType))
+ }
+
+ test("alterTable: add required column") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.addColumn(Array("ts"), TimestampType, false))
+
+ assert(updated.schema == schema.add("ts", TimestampType, nullable = false))
+ }
+
+ test("alterTable: add column with comment") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.addColumn(Array("ts"), TimestampType, false, "comment text"))
+
+ val field = StructField("ts", TimestampType, nullable = false).withComment("comment text")
+ assert(updated.schema == schema.add(field))
+ }
+
+ test("alterTable: add nested column") {
+ val catalog = newCatalog()
+
+ val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
+ val tableSchema = schema.add("point", pointStruct)
+
+ val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps)
+
+ assert(table.schema == tableSchema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.addColumn(Array("point", "z"), DoubleType))
+
+ val expectedSchema = schema.add("point", pointStruct.add("z", DoubleType))
+
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: add column to primitive field fails") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val exc = intercept[IllegalArgumentException] {
+ catalog.alterTable(testIdent, TableChange.addColumn(Array("data", "ts"), TimestampType))
+ }
+
+ assert(exc.getMessage.contains("Not a struct"))
+ assert(exc.getMessage.contains("data"))
+
+ // the table has not changed
+ assert(catalog.loadTable(testIdent).schema == schema)
+ }
+
+ test("alterTable: add field to missing column fails") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val exc = intercept[IllegalArgumentException] {
+ catalog.alterTable(testIdent,
+ TableChange.addColumn(Array("missing_col", "new_field"), StringType))
+ }
+
+ assert(exc.getMessage.contains("missing_col"))
+ assert(exc.getMessage.contains("Cannot find"))
+ }
+
+ test("alterTable: update column data type") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val updated = catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType))
+
+ val expectedSchema = new StructType().add("id", LongType).add("data", StringType)
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: update column data type and nullability") {
+ val catalog = newCatalog()
+
+ val originalSchema = new StructType()
+ .add("id", IntegerType, nullable = false)
+ .add("data", StringType)
+ val table = catalog.createTable(testIdent, originalSchema, Array.empty, emptyProps)
+
+ assert(table.schema == originalSchema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.updateColumnType(Array("id"), LongType, true))
+
+ val expectedSchema = new StructType().add("id", LongType).add("data", StringType)
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: update optional column to required fails") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val exc = intercept[IllegalArgumentException] {
+ catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType, false))
+ }
+
+ assert(exc.getMessage.contains("Cannot change optional column to required"))
+ assert(exc.getMessage.contains("id"))
+ }
+
+ test("alterTable: update missing column fails") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val exc = intercept[IllegalArgumentException] {
+ catalog.alterTable(testIdent,
+ TableChange.updateColumnType(Array("missing_col"), LongType))
+ }
+
+ assert(exc.getMessage.contains("missing_col"))
+ assert(exc.getMessage.contains("Cannot find"))
+ }
+
+ test("alterTable: add comment") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.updateColumnComment(Array("id"), "comment text"))
+
+ val expectedSchema = new StructType()
+ .add("id", IntegerType, nullable = true, "comment text")
+ .add("data", StringType)
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: replace comment") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("id"), "comment text"))
+
+ val expectedSchema = new StructType()
+ .add("id", IntegerType, nullable = true, "replacement comment")
+ .add("data", StringType)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.updateColumnComment(Array("id"), "replacement comment"))
+
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: add comment to missing column fails") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val exc = intercept[IllegalArgumentException] {
+ catalog.alterTable(testIdent,
+ TableChange.updateColumnComment(Array("missing_col"), "comment"))
+ }
+
+ assert(exc.getMessage.contains("missing_col"))
+ assert(exc.getMessage.contains("Cannot find"))
+ }
+
+ test("alterTable: rename top-level column") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val updated = catalog.alterTable(testIdent, TableChange.renameColumn(Array("id"), "some_id"))
+
+ val expectedSchema = new StructType().add("some_id", IntegerType).add("data", StringType)
+
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: rename nested column") {
+ val catalog = newCatalog()
+
+ val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
+ val tableSchema = schema.add("point", pointStruct)
+
+ val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps)
+
+ assert(table.schema == tableSchema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.renameColumn(Array("point", "x"), "first"))
+
+ val newPointStruct = new StructType().add("first", DoubleType).add("y", DoubleType)
+ val expectedSchema = schema.add("point", newPointStruct)
+
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: rename struct column") {
+ val catalog = newCatalog()
+
+ val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
+ val tableSchema = schema.add("point", pointStruct)
+
+ val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps)
+
+ assert(table.schema == tableSchema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.renameColumn(Array("point"), "p"))
+
+ val newPointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
+ val expectedSchema = schema.add("p", newPointStruct)
+
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: rename missing column fails") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val exc = intercept[IllegalArgumentException] {
+ catalog.alterTable(testIdent,
+ TableChange.renameColumn(Array("missing_col"), "new_name"))
+ }
+
+ assert(exc.getMessage.contains("missing_col"))
+ assert(exc.getMessage.contains("Cannot find"))
+ }
+
+ test("alterTable: multiple changes") {
+ val catalog = newCatalog()
+
+ val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
+ val tableSchema = schema.add("point", pointStruct)
+
+ val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps)
+
+ assert(table.schema == tableSchema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.renameColumn(Array("point", "x"), "first"),
+ TableChange.renameColumn(Array("point", "y"), "second"))
+
+ val newPointStruct = new StructType().add("first", DoubleType).add("second", DoubleType)
+ val expectedSchema = schema.add("point", newPointStruct)
+
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: delete top-level column") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.deleteColumn(Array("id")))
+
+ val expectedSchema = new StructType().add("data", StringType)
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: delete nested column") {
+ val catalog = newCatalog()
+
+ val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
+ val tableSchema = schema.add("point", pointStruct)
+
+ val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps)
+
+ assert(table.schema == tableSchema)
+
+ val updated = catalog.alterTable(testIdent,
+ TableChange.deleteColumn(Array("point", "y")))
+
+ val newPointStruct = new StructType().add("x", DoubleType)
+ val expectedSchema = schema.add("point", newPointStruct)
+
+ assert(updated.schema == expectedSchema)
+ }
+
+ test("alterTable: delete missing column fails") {
+ val catalog = newCatalog()
+
+ val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(table.schema == schema)
+
+ val exc = intercept[IllegalArgumentException] {
+ catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col")))
+ }
+
+ assert(exc.getMessage.contains("missing_col"))
+ assert(exc.getMessage.contains("Cannot find"))
+ }
+
+ test("alterTable: delete missing nested column fails") {
+ val catalog = newCatalog()
+
+ val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType)
+ val tableSchema = schema.add("point", pointStruct)
+
+ val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps)
+
+ assert(table.schema == tableSchema)
+
+ val exc = intercept[IllegalArgumentException] {
+ catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z")))
+ }
+
+ assert(exc.getMessage.contains("z"))
+ assert(exc.getMessage.contains("Cannot find"))
+ }
+
+ test("alterTable: table does not exist") {
+ val catalog = newCatalog()
+
+ val exc = intercept[NoSuchTableException] {
+ catalog.alterTable(testIdent, TableChange.setProperty("prop", "val"))
+ }
+
+ assert(exc.message.contains(testIdent.quoted))
+ assert(exc.message.contains("not found"))
+ }
+
+ test("dropTable") {
+ val catalog = newCatalog()
+
+ assert(!catalog.tableExists(testIdent))
+
+ catalog.createTable(testIdent, schema, Array.empty, emptyProps)
+
+ assert(catalog.tableExists(testIdent))
+
+ val wasDropped = catalog.dropTable(testIdent)
+
+ assert(wasDropped)
+ assert(!catalog.tableExists(testIdent))
+ }
+
+ test("dropTable: table does not exist") {
+ val catalog = newCatalog()
+
+ assert(!catalog.tableExists(testIdent))
+
+ val wasDropped = catalog.dropTable(testIdent)
+
+ assert(!wasDropped)
+ assert(!catalog.tableExists(testIdent))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala
new file mode 100644
index 0000000000000..78b4763484cc0
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalog.v2
+
+import java.util
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.catalog.v2.TableChange.{AddColumn, DeleteColumn, RemoveProperty, RenameColumn, SetProperty, UpdateColumnComment, UpdateColumnType}
+import org.apache.spark.sql.catalog.v2.expressions.Transform
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.sources.v2.{Table, TableCapability}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class TestTableCatalog extends TableCatalog {
+ import CatalogV2Implicits._
+
+ private val tables: util.Map[Identifier, Table] = new ConcurrentHashMap[Identifier, Table]()
+ private var _name: Option[String] = None
+
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
+ _name = Some(name)
+ }
+
+ override def name: String = _name.get
+
+ override def listTables(namespace: Array[String]): Array[Identifier] = {
+ tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
+ }
+
+ override def loadTable(ident: Identifier): Table = {
+ Option(tables.get(ident)) match {
+ case Some(table) =>
+ table
+ case _ =>
+ throw new NoSuchTableException(ident)
+ }
+ }
+
+ override def createTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+
+ if (tables.containsKey(ident)) {
+ throw new TableAlreadyExistsException(ident)
+ }
+
+ if (partitions.nonEmpty) {
+ throw new UnsupportedOperationException(
+ s"Catalog $name: Partitioned tables are not supported")
+ }
+
+ val table = InMemoryTable(ident.quoted, schema, properties)
+
+ tables.put(ident, table)
+
+ table
+ }
+
+ override def alterTable(ident: Identifier, changes: TableChange*): Table = {
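+ // Build a new immutable table with the property and schema changes applied, then swap it in.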
+ val table = loadTable(ident)
+ val properties = TestTableCatalog.applyPropertiesChanges(table.properties, changes)
+ val schema = TestTableCatalog.applySchemaChanges(table.schema, changes)
+ val newTable = InMemoryTable(table.name, schema, properties)
+
+ tables.put(ident, newTable)
+
+ newTable
+ }
+
+ override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined
+}
+
+object TestTableCatalog {
+ /**
+ * Apply properties changes to a map and return the result.
+ */
+ def applyPropertiesChanges(
+ properties: util.Map[String, String],
+ changes: Seq[TableChange]): util.Map[String, String] = {
+ val newProperties = new util.HashMap[String, String](properties)
+
+ changes.foreach {
+ case set: SetProperty =>
+ newProperties.put(set.property, set.value)
+
+ case unset: RemoveProperty =>
+ newProperties.remove(unset.property)
+
+ case _ =>
+ // ignore non-property changes
+ }
+
+ Collections.unmodifiableMap(newProperties)
+ }
+
+ /**
+ * Apply schema changes to a schema and return the result.
+ */
+ def applySchemaChanges(schema: StructType, changes: Seq[TableChange]): StructType = {
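+ // Apply each change in order; every step produces the schema that the next change sees.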
+ changes.foldLeft(schema) { (schema, change) =>
+ change match {
+ case add: AddColumn =>
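+ // A single field name adds a top-level column; a longer path adds a field to the named nested struct.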
+ add.fieldNames match {
+ case Array(name) =>
+ val newField = StructField(name, add.dataType, nullable = add.isNullable)
+ Option(add.comment) match {
+ case Some(comment) =>
+ schema.add(newField.withComment(comment))
+ case _ =>
+ schema.add(newField)
+ }
+
+ case names =>
+ replace(schema, names.init, parent => parent.dataType match {
+ case parentType: StructType =>
+ val field = StructField(names.last, add.dataType, nullable = add.isNullable)
+ val newParentType = Option(add.comment) match {
+ case Some(comment) =>
+ parentType.add(field.withComment(comment))
+ case None =>
+ parentType.add(field)
+ }
+
+ Some(StructField(parent.name, newParentType, parent.nullable, parent.metadata))
+
+ case _ =>
+ throw new IllegalArgumentException(s"Not a struct: ${names.init.last}")
+ })
+ }
+
+ case rename: RenameColumn =>
+ replace(schema, rename.fieldNames, field =>
+ Some(StructField(rename.newName, field.dataType, field.nullable, field.metadata)))
+
+ case update: UpdateColumnType =>
+ replace(schema, update.fieldNames, field => {
+ if (!update.isNullable && field.nullable) {
+ throw new IllegalArgumentException(
+ s"Cannot change optional column to required: $field.name")
+ }
+ Some(StructField(field.name, update.newDataType, update.isNullable, field.metadata))
+ })
+
+ case update: UpdateColumnComment =>
+ replace(schema, update.fieldNames, field =>
+ Some(field.withComment(update.newComment)))
+
+ case delete: DeleteColumn =>
+ replace(schema, delete.fieldNames, _ => None)
+
+ case _ =>
+ // ignore non-schema changes
+ schema
+ }
+ }
+ }
+
+ private def replace(
+ struct: StructType,
+ path: Seq[String],
+ update: StructField => Option[StructField]): StructType = {
+
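+ // Recursively walk the field path and apply `update` to the field at the end of it.
+ // Returning None from `update` drops that field, which is how deleteColumn is implemented.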
+ val pos = struct.getFieldIndex(path.head)
+ .getOrElse(throw new IllegalArgumentException(s"Cannot find field: ${path.head}"))
+ val field = struct.fields(pos)
+ val replacement: Option[StructField] = if (path.tail.isEmpty) {
+ update(field)
+ } else {
+ field.dataType match {
+ case nestedStruct: StructType =>
+ val updatedType: StructType = replace(nestedStruct, path.tail, update)
+ Some(StructField(field.name, updatedType, field.nullable, field.metadata))
+ case _ =>
+ throw new IllegalArgumentException(s"Not a struct: ${path.head}")
+ }
+ }
+
+ val newFields = struct.fields.zipWithIndex.flatMap {
+ case (_, index) if pos == index =>
+ replacement
+ case (other, _) =>
+ Some(other)
+ }
+
+ new StructType(newFields)
+ }
+}
+
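+/**
+ * A minimal in-memory [[Table]] used by [[TestTableCatalog]]: it is unpartitioned and declares no
+ * read/write capabilities.
+ */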
+case class InMemoryTable(
+ name: String,
+ schema: StructType,
+ override val properties: util.Map[String, String]) extends Table {
+ override def partitioning: Array[Transform] = Array.empty
+ override def capabilities: util.Set[TableCapability] = InMemoryTable.CAPABILITIES
+}
+
+object InMemoryTable {
+ val CAPABILITIES: util.Set[TableCapability] = Set.empty[TableCapability].asJava
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
new file mode 100644
index 0000000000000..1ce8852f71bc8
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog, TestTableCatalog}
+import org.apache.spark.sql.catalog.v2.expressions.LogicalExpressions
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode}
+import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class CreateTablePartitioningValidationSuite extends AnalysisTest {
+ import CreateTablePartitioningValidationSuite._
+
+ test("CreateTableAsSelect: fail missing top-level column") {
+ val plan = CreateTableAsSelect(
+ catalog,
+ Identifier.of(Array(), "table_name"),
+ LogicalExpressions.bucket(4, "does_not_exist") :: Nil,
+ TestRelation2,
+ Map.empty,
+ Map.empty,
+ ignoreIfExists = false)
+
+ assert(!plan.resolved)
+ assertAnalysisError(plan, Seq(
+ "Invalid partitioning",
+ "does_not_exist is missing or is in a map or array"))
+ }
+
+ test("CreateTableAsSelect: fail missing top-level column nested reference") {
+ val plan = CreateTableAsSelect(
+ catalog,
+ Identifier.of(Array(), "table_name"),
+ LogicalExpressions.bucket(4, "does_not_exist.z") :: Nil,
+ TestRelation2,
+ Map.empty,
+ Map.empty,
+ ignoreIfExists = false)
+
+ assert(!plan.resolved)
+ assertAnalysisError(plan, Seq(
+ "Invalid partitioning",
+ "does_not_exist.z is missing or is in a map or array"))
+ }
+
+ test("CreateTableAsSelect: fail missing nested column") {
+ val plan = CreateTableAsSelect(
+ catalog,
+ Identifier.of(Array(), "table_name"),
+ LogicalExpressions.bucket(4, "point.z") :: Nil,
+ TestRelation2,
+ Map.empty,
+ Map.empty,
+ ignoreIfExists = false)
+
+ assert(!plan.resolved)
+ assertAnalysisError(plan, Seq(
+ "Invalid partitioning",
+ "point.z is missing or is in a map or array"))
+ }
+
+ test("CreateTableAsSelect: fail with multiple errors") {
+ val plan = CreateTableAsSelect(
+ catalog,
+ Identifier.of(Array(), "table_name"),
+ LogicalExpressions.bucket(4, "does_not_exist", "point.z") :: Nil,
+ TestRelation2,
+ Map.empty,
+ Map.empty,
+ ignoreIfExists = false)
+
+ assert(!plan.resolved)
+ assertAnalysisError(plan, Seq(
+ "Invalid partitioning",
+ "point.z is missing or is in a map or array",
+ "does_not_exist is missing or is in a map or array"))
+ }
+
+ test("CreateTableAsSelect: success with top-level column") {
+ val plan = CreateTableAsSelect(
+ catalog,
+ Identifier.of(Array(), "table_name"),
+ LogicalExpressions.bucket(4, "id") :: Nil,
+ TestRelation2,
+ Map.empty,
+ Map.empty,
+ ignoreIfExists = false)
+
+ assertAnalysisSuccess(plan)
+ }
+
+ test("CreateTableAsSelect: success using nested column") {
+ val plan = CreateTableAsSelect(
+ catalog,
+ Identifier.of(Array(), "table_name"),
+ LogicalExpressions.bucket(4, "point.x") :: Nil,
+ TestRelation2,
+ Map.empty,
+ Map.empty,
+ ignoreIfExists = false)
+
+ assertAnalysisSuccess(plan)
+ }
+
+ test("CreateTableAsSelect: success using complex column") {
+ val plan = CreateTableAsSelect(
+ catalog,
+ Identifier.of(Array(), "table_name"),
+ LogicalExpressions.bucket(4, "point") :: Nil,
+ TestRelation2,
+ Map.empty,
+ Map.empty,
+ ignoreIfExists = false)
+
+ assertAnalysisSuccess(plan)
+ }
+}
+
+private object CreateTablePartitioningValidationSuite {
+ val catalog: TableCatalog = {
+ val cat = new TestTableCatalog()
+ cat.initialize("test", CaseInsensitiveStringMap.empty())
+ cat
+ }
+
+ val schema: StructType = new StructType()
+ .add("id", LongType)
+ .add("data", StringType)
+ .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))
+}
+
+private case object TestRelation2 extends LeafNode with NamedRelation {
+ override def name: String = "source_relation"
+ override def output: Seq[AttributeReference] =
+ CreateTablePartitioningValidationSuite.schema.toAttributes
+}
+
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala
index 6c899b610ac5b..48b43fcccacef 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala
@@ -19,15 +19,98 @@ package org.apache.spark.sql.catalyst.analysis
import java.util.Locale
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, UpCast}
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, Project}
-import org.apache.spark.sql.types.{DoubleType, FloatType, StructField, StructType}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, LessThanOrEqual, Literal}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+
+class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite {
+ override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
+ AppendData.byName(table, query)
+ }
+
+ override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
+ AppendData.byPosition(table, query)
+ }
+}
+
+class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite {
+ override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
+ OverwritePartitionsDynamic.byName(table, query)
+ }
+
+ override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
+ OverwritePartitionsDynamic.byPosition(table, query)
+ }
+}
+
+class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite {
+ override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
+ OverwriteByExpression.byName(table, query, Literal(true))
+ }
+
+ override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = {
+ OverwriteByExpression.byPosition(table, query, Literal(true))
+ }
+
+ test("delete expression is resolved using table fields") {
+ val table = TestRelation(StructType(Seq(
+ StructField("x", DoubleType, nullable = false),
+ StructField("y", DoubleType))).toAttributes)
+
+ val query = TestRelation(StructType(Seq(
+ StructField("a", DoubleType, nullable = false),
+ StructField("b", DoubleType))).toAttributes)
+
+ val a = query.output.head
+ val b = query.output.last
+ val x = table.output.head
+
+ val parsedPlan = OverwriteByExpression.byPosition(table, query,
+ LessThanOrEqual(UnresolvedAttribute(Seq("x")), Literal(15.0d)))
+
+ val expectedPlan = OverwriteByExpression.byPosition(table,
+ Project(Seq(
+ Alias(Cast(a, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(),
+ Alias(Cast(b, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()),
+ query),
+ LessThanOrEqual(
+ AttributeReference("x", DoubleType, nullable = false)(x.exprId),
+ Literal(15.0d)))
+
+ assertNotResolved(parsedPlan)
+ checkAnalysis(parsedPlan, expectedPlan)
+ assertResolved(expectedPlan)
+ }
+
+ test("delete expression is not resolved using query fields") {
+ val xRequiredTable = TestRelation(StructType(Seq(
+ StructField("x", DoubleType, nullable = false),
+ StructField("y", DoubleType))).toAttributes)
+
+ val query = TestRelation(StructType(Seq(
+ StructField("a", DoubleType, nullable = false),
+ StructField("b", DoubleType))).toAttributes)
+
+ // The write itself resolves (checked above); this plan stays unresolved because of the delete expression.
+ val parsedPlan = OverwriteByExpression.byPosition(xRequiredTable, query,
+ LessThanOrEqual(UnresolvedAttribute(Seq("a")), Literal(15.0d)))
+
+ assertNotResolved(parsedPlan)
+ assertAnalysisError(parsedPlan, Seq("cannot resolve", "`a`", "given input columns", "x, y"))
+ }
+}
case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation {
override def name: String = "table-name"
}
-class DataSourceV2AnalysisSuite extends AnalysisTest {
+case class TestRelationAcceptAnySchema(output: Seq[AttributeReference])
+ extends LeafNode with NamedRelation {
+ override def name: String = "test-name"
+ override def skipSchemaResolution: Boolean = true
+}
+
+abstract class DataSourceV2AnalysisSuite extends AnalysisTest {
val table = TestRelation(StructType(Seq(
StructField("x", FloatType),
StructField("y", FloatType))).toAttributes)
@@ -40,21 +123,25 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
StructField("x", DoubleType),
StructField("y", DoubleType))).toAttributes)
- test("Append.byName: basic behavior") {
+ def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan
+
+ def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan
+
+ test("byName: basic behavior") {
val query = TestRelation(table.schema.toAttributes)
- val parsedPlan = AppendData.byName(table, query)
+ val parsedPlan = byName(table, query)
checkAnalysis(parsedPlan, parsedPlan)
assertResolved(parsedPlan)
}
- test("Append.byName: does not match by position") {
+ test("byName: does not match by position") {
val query = TestRelation(StructType(Seq(
StructField("a", FloatType),
StructField("b", FloatType))).toAttributes)
- val parsedPlan = AppendData.byName(table, query)
+ val parsedPlan = byName(table, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -62,12 +149,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Cannot find data for output column", "'x'", "'y'"))
}
- test("Append.byName: case sensitive column resolution") {
+ test("byName: case sensitive column resolution") {
val query = TestRelation(StructType(Seq(
StructField("X", FloatType), // doesn't match case!
StructField("y", FloatType))).toAttributes)
- val parsedPlan = AppendData.byName(table, query)
+ val parsedPlan = byName(table, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -76,7 +163,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
caseSensitive = true)
}
- test("Append.byName: case insensitive column resolution") {
+ test("byName: case insensitive column resolution") {
val query = TestRelation(StructType(Seq(
StructField("X", FloatType), // doesn't match case!
StructField("y", FloatType))).toAttributes)
@@ -84,8 +171,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
val X = query.output.head
val y = query.output.last
- val parsedPlan = AppendData.byName(table, query)
- val expectedPlan = AppendData.byName(table,
+ val parsedPlan = byName(table, query)
+ val expectedPlan = byName(table,
Project(Seq(
Alias(Cast(toLower(X), FloatType, Some(conf.sessionLocalTimeZone)), "x")(),
Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()),
@@ -96,7 +183,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
assertResolved(expectedPlan)
}
- test("Append.byName: data columns are reordered by name") {
+ test("byName: data columns are reordered by name") {
// out of order
val query = TestRelation(StructType(Seq(
StructField("y", FloatType),
@@ -105,8 +192,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
val y = query.output.head
val x = query.output.last
- val parsedPlan = AppendData.byName(table, query)
- val expectedPlan = AppendData.byName(table,
+ val parsedPlan = byName(table, query)
+ val expectedPlan = byName(table,
Project(Seq(
Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "x")(),
Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()),
@@ -117,26 +204,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
assertResolved(expectedPlan)
}
- test("Append.byName: fail nullable data written to required columns") {
- val parsedPlan = AppendData.byName(requiredTable, table)
+ test("byName: fail nullable data written to required columns") {
+ val parsedPlan = byName(requiredTable, table)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write incompatible data to table", "'table-name'",
"Cannot write nullable values to non-null column", "'x'", "'y'"))
}
- test("Append.byName: allow required data written to nullable columns") {
- val parsedPlan = AppendData.byName(table, requiredTable)
+ test("byName: allow required data written to nullable columns") {
+ val parsedPlan = byName(table, requiredTable)
assertResolved(parsedPlan)
checkAnalysis(parsedPlan, parsedPlan)
}
- test("Append.byName: missing required columns cause failure and are identified by name") {
+ test("byName: missing required columns cause failure and are identified by name") {
// missing required field x
val query = TestRelation(StructType(Seq(
StructField("y", FloatType, nullable = false))).toAttributes)
- val parsedPlan = AppendData.byName(requiredTable, query)
+ val parsedPlan = byName(requiredTable, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -144,12 +231,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Cannot find data for output column", "'x'"))
}
- test("Append.byName: missing optional columns cause failure and are identified by name") {
+ test("byName: missing optional columns cause failure and are identified by name") {
// missing optional field x
val query = TestRelation(StructType(Seq(
StructField("y", FloatType))).toAttributes)
- val parsedPlan = AppendData.byName(table, query)
+ val parsedPlan = byName(table, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -157,8 +244,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Cannot find data for output column", "'x'"))
}
- test("Append.byName: fail canWrite check") {
- val parsedPlan = AppendData.byName(table, widerTable)
+ test("byName: fail canWrite check") {
+ val parsedPlan = byName(table, widerTable)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -166,12 +253,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType"))
}
- test("Append.byName: insert safe cast") {
+ test("byName: insert safe cast") {
val x = table.output.head
val y = table.output.last
- val parsedPlan = AppendData.byName(widerTable, table)
- val expectedPlan = AppendData.byName(widerTable,
+ val parsedPlan = byName(widerTable, table)
+ val expectedPlan = byName(widerTable,
Project(Seq(
Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(),
Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()),
@@ -182,13 +269,13 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
assertResolved(expectedPlan)
}
- test("Append.byName: fail extra data fields") {
+ test("byName: fail extra data fields") {
val query = TestRelation(StructType(Seq(
StructField("x", FloatType),
StructField("y", FloatType),
StructField("z", FloatType))).toAttributes)
- val parsedPlan = AppendData.byName(table, query)
+ val parsedPlan = byName(table, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -197,7 +284,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Data columns: 'x', 'y', 'z'"))
}
- test("Append.byName: multiple field errors are reported") {
+ test("byName: multiple field errors are reported") {
val xRequiredTable = TestRelation(StructType(Seq(
StructField("x", FloatType, nullable = false),
StructField("y", DoubleType))).toAttributes)
@@ -206,7 +293,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
StructField("x", DoubleType),
StructField("b", FloatType))).toAttributes)
- val parsedPlan = AppendData.byName(xRequiredTable, query)
+ val parsedPlan = byName(xRequiredTable, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -216,7 +303,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Cannot find data for output column", "'y'"))
}
- test("Append.byPosition: basic behavior") {
+ test("byPosition: basic behavior") {
val query = TestRelation(StructType(Seq(
StructField("a", FloatType),
StructField("b", FloatType))).toAttributes)
@@ -224,8 +311,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
val a = query.output.head
val b = query.output.last
- val parsedPlan = AppendData.byPosition(table, query)
- val expectedPlan = AppendData.byPosition(table,
+ val parsedPlan = byPosition(table, query)
+ val expectedPlan = byPosition(table,
Project(Seq(
Alias(Cast(a, FloatType, Some(conf.sessionLocalTimeZone)), "x")(),
Alias(Cast(b, FloatType, Some(conf.sessionLocalTimeZone)), "y")()),
@@ -236,7 +323,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
assertResolved(expectedPlan)
}
- test("Append.byPosition: data columns are not reordered") {
+ test("byPosition: data columns are not reordered") {
// out of order
val query = TestRelation(StructType(Seq(
StructField("y", FloatType),
@@ -245,8 +332,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
val y = query.output.head
val x = query.output.last
- val parsedPlan = AppendData.byPosition(table, query)
- val expectedPlan = AppendData.byPosition(table,
+ val parsedPlan = byPosition(table, query)
+ val expectedPlan = byPosition(table,
Project(Seq(
Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "x")(),
Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "y")()),
@@ -257,26 +344,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
assertResolved(expectedPlan)
}
- test("Append.byPosition: fail nullable data written to required columns") {
- val parsedPlan = AppendData.byPosition(requiredTable, table)
+ test("byPosition: fail nullable data written to required columns") {
+ val parsedPlan = byPosition(requiredTable, table)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
"Cannot write incompatible data to table", "'table-name'",
"Cannot write nullable values to non-null column", "'x'", "'y'"))
}
- test("Append.byPosition: allow required data written to nullable columns") {
- val parsedPlan = AppendData.byPosition(table, requiredTable)
+ test("byPosition: allow required data written to nullable columns") {
+ val parsedPlan = byPosition(table, requiredTable)
assertResolved(parsedPlan)
checkAnalysis(parsedPlan, parsedPlan)
}
- test("Append.byPosition: missing required columns cause failure") {
+ test("byPosition: missing required columns cause failure") {
+ // missing required field x
val query = TestRelation(StructType(Seq(
StructField("y", FloatType, nullable = false))).toAttributes)
- val parsedPlan = AppendData.byPosition(requiredTable, query)
+ val parsedPlan = byPosition(requiredTable, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -285,12 +372,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Data columns: 'y'"))
}
- test("Append.byPosition: missing optional columns cause failure") {
+ test("byPosition: missing optional columns cause failure") {
// missing optional field x
val query = TestRelation(StructType(Seq(
StructField("y", FloatType))).toAttributes)
- val parsedPlan = AppendData.byPosition(table, query)
+ val parsedPlan = byPosition(table, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -299,12 +386,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Data columns: 'y'"))
}
- test("Append.byPosition: fail canWrite check") {
+ test("byPosition: fail canWrite check") {
val widerTable = TestRelation(StructType(Seq(
StructField("a", DoubleType),
StructField("b", DoubleType))).toAttributes)
- val parsedPlan = AppendData.byPosition(table, widerTable)
+ val parsedPlan = byPosition(table, widerTable)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -312,7 +399,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType"))
}
- test("Append.byPosition: insert safe cast") {
+ test("byPosition: insert safe cast") {
val widerTable = TestRelation(StructType(Seq(
StructField("a", DoubleType),
StructField("b", DoubleType))).toAttributes)
@@ -320,8 +407,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
val x = table.output.head
val y = table.output.last
- val parsedPlan = AppendData.byPosition(widerTable, table)
- val expectedPlan = AppendData.byPosition(widerTable,
+ val parsedPlan = byPosition(widerTable, table)
+ val expectedPlan = byPosition(widerTable,
Project(Seq(
Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "a")(),
Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "b")()),
@@ -332,13 +419,13 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
assertResolved(expectedPlan)
}
- test("Append.byPosition: fail extra data fields") {
+ test("byPosition: fail extra data fields") {
val query = TestRelation(StructType(Seq(
StructField("a", FloatType),
StructField("b", FloatType),
StructField("c", FloatType))).toAttributes)
- val parsedPlan = AppendData.byName(table, query)
+ val parsedPlan = byName(table, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -347,7 +434,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Data columns: 'a', 'b', 'c'"))
}
- test("Append.byPosition: multiple field errors are reported") {
+ test("byPosition: multiple field errors are reported") {
val xRequiredTable = TestRelation(StructType(Seq(
StructField("x", FloatType, nullable = false),
StructField("y", DoubleType))).toAttributes)
@@ -356,7 +443,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
StructField("x", DoubleType),
StructField("b", FloatType))).toAttributes)
- val parsedPlan = AppendData.byPosition(xRequiredTable, query)
+ val parsedPlan = byPosition(xRequiredTable, query)
assertNotResolved(parsedPlan)
assertAnalysisError(parsedPlan, Seq(
@@ -365,6 +452,27 @@ class DataSourceV2AnalysisSuite extends AnalysisTest {
"Cannot safely cast", "'x'", "DoubleType to FloatType"))
}
+ test("bypass output column resolution") {
+ val table = TestRelationAcceptAnySchema(StructType(Seq(
+ StructField("a", FloatType, nullable = false),
+ StructField("b", DoubleType))).toAttributes)
+
+ val query = TestRelation(StructType(Seq(
+ StructField("s", StringType))).toAttributes)
+
+ withClue("byName") {
+ val parsedPlan = byName(table, query)
+ assertResolved(parsedPlan)
+ checkAnalysis(parsedPlan, parsedPlan)
+ }
+
+ withClue("byPosition") {
+ val parsedPlan = byPosition(table, query)
+ assertResolved(parsedPlan)
+ checkAnalysis(parsedPlan, parsedPlan)
+ }
+ }
+
def assertNotResolved(logicalPlan: LogicalPlan): Unit = {
assert(!logicalPlan.resolved, s"Plan should not be resolved: $logicalPlan")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala
new file mode 100644
index 0000000000000..783751ff79865
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.catalog.v2
+
+import org.scalatest.Inside
+import org.scalatest.Matchers._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
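+/**
+ * A no-op catalog plugin used only to check catalog name resolution in this suite; it stores nothing.
+ */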
+private case class TestCatalogPlugin(override val name: String) extends CatalogPlugin {
+
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = ()
+}
+
+class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside {
+ import CatalystSqlParser._
+
+ private val catalogs = Seq("prod", "test").map(x => x -> new TestCatalogPlugin(x)).toMap
+
+ override def lookupCatalog(name: String): CatalogPlugin =
+ catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found"))
+
+ test("catalog object identifier") {
+ Seq(
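+ // (sql text, expected catalog, expected namespace, expected identifier name)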
+ ("tbl", None, Seq.empty, "tbl"),
+ ("db.tbl", None, Seq("db"), "tbl"),
+ ("prod.func", catalogs.get("prod"), Seq.empty, "func"),
+ ("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"),
+ ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"),
+ ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"),
+ ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"),
+ ("`db.tbl`", None, Seq.empty, "db.tbl"),
+ ("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"),
+ ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None,
+ Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach {
+ case (sql, expectedCatalog, namespace, name) =>
+ inside(parseMultipartIdentifier(sql)) {
+ case CatalogObjectIdentifier(catalog, ident) =>
+ catalog shouldEqual expectedCatalog
+ ident shouldEqual Identifier.of(namespace.toArray, name)
+ }
+ }
+ }
+
+ test("table identifier") {
+ Seq(
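+ // (sql text, expected table name, expected database)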
+ ("tbl", "tbl", None),
+ ("db.tbl", "tbl", Some("db")),
+ ("`db.tbl`", "db.tbl", None),
+ ("parquet.`file:/tmp/db.tbl`", "file:/tmp/db.tbl", Some("parquet")),
+ ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", "s3://buck/tmp/abc.json",
+ Some("org.apache.spark.sql.json"))).foreach {
+ case (sql, table, db) =>
+ inside (parseMultipartIdentifier(sql)) {
+ case AsTableIdentifier(ident) =>
+ ident shouldEqual TableIdentifier(table, db)
+ }
+ }
+ Seq(
+ "prod.func",
+ "prod.db.tbl",
+ "ns1.ns2.tbl").foreach { sql =>
+ parseMultipartIdentifier(sql) match {
+ case AsTableIdentifier(_) =>
+ fail(s"$sql should not be resolved as TableIdentifier")
+ case _ =>
+ }
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
new file mode 100644
index 0000000000000..35cd813ae65c5
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -0,0 +1,397 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
+import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType, TimestampType}
+import org.apache.spark.unsafe.types.UTF8String
+
+class DDLParserSuite extends AnalysisTest {
+ import CatalystSqlParser._
+
+ private def intercept(sqlCommand: String, messages: String*): Unit = {
+ val e = intercept[ParseException](parsePlan(sqlCommand))
+ messages.foreach { message =>
+ assert(e.message.contains(message))
+ }
+ }
+
+ private def parseCompare(sql: String, expected: LogicalPlan): Unit = {
+ comparePlans(parsePlan(sql), expected, checkAnalysis = false)
+ }
+
+ test("create table using - schema") {
+ val sql = "CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet"
+
+ parsePlan(sql) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("my_tab"))
+ assert(create.tableSchema == new StructType()
+ .add("a", IntegerType, nullable = true, "test")
+ .add("b", StringType))
+ assert(create.partitioning.isEmpty)
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties.isEmpty)
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.isEmpty)
+ assert(create.comment.isEmpty)
+ assert(!create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+
+ intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet",
+ "no viable alternative at input")
+ }
+
+ test("create table - with IF NOT EXISTS") {
+ val sql = "CREATE TABLE IF NOT EXISTS my_tab(a INT, b STRING) USING parquet"
+
+ parsePlan(sql) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("my_tab"))
+ assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
+ assert(create.partitioning.isEmpty)
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties.isEmpty)
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.isEmpty)
+ assert(create.comment.isEmpty)
+ assert(create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("create table - with partitioned by") {
+ val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " +
+ "USING parquet PARTITIONED BY (a)"
+
+ parsePlan(query) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("my_tab"))
+ assert(create.tableSchema == new StructType()
+ .add("a", IntegerType, nullable = true, "test")
+ .add("b", StringType))
+ assert(create.partitioning == Seq(IdentityTransform(FieldReference("a"))))
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties.isEmpty)
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.isEmpty)
+ assert(create.comment.isEmpty)
+ assert(!create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $query")
+ }
+ }
+
+ test("create table - partitioned by transforms") {
+ val sql =
+ """
+ |CREATE TABLE my_tab (a INT, b STRING, ts TIMESTAMP) USING parquet
+ |PARTITIONED BY (
+ | a,
+ | bucket(16, b),
+ | years(ts),
+ | months(ts),
+ | days(ts),
+ | hours(ts),
+ | foo(a, "bar", 34))
+ """.stripMargin
+
+ parsePlan(sql) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("my_tab"))
+ assert(create.tableSchema == new StructType()
+ .add("a", IntegerType)
+ .add("b", StringType)
+ .add("ts", TimestampType))
+ assert(create.partitioning == Seq(
+ IdentityTransform(FieldReference("a")),
+ BucketTransform(LiteralValue(16, IntegerType), Seq(FieldReference("b"))),
+ YearsTransform(FieldReference("ts")),
+ MonthsTransform(FieldReference("ts")),
+ DaysTransform(FieldReference("ts")),
+ HoursTransform(FieldReference("ts")),
+ ApplyTransform("foo", Seq(
+ FieldReference("a"),
+ LiteralValue(UTF8String.fromString("bar"), StringType),
+ LiteralValue(34, IntegerType)))))
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties.isEmpty)
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.isEmpty)
+ assert(create.comment.isEmpty)
+ assert(!create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("create table - with bucket") {
+ val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " +
+ "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS"
+
+ parsePlan(query) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("my_tab"))
+ assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
+ assert(create.partitioning.isEmpty)
+ assert(create.bucketSpec.contains(BucketSpec(5, Seq("a"), Seq("b"))))
+ assert(create.properties.isEmpty)
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.isEmpty)
+ assert(create.comment.isEmpty)
+ assert(!create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $query")
+ }
+ }
+
+ test("create table - with comment") {
+ val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'"
+
+ parsePlan(sql) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("my_tab"))
+ assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
+ assert(create.partitioning.isEmpty)
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties.isEmpty)
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.isEmpty)
+ assert(create.comment.contains("abc"))
+ assert(!create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("create table - with table properties") {
+ val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')"
+
+ parsePlan(sql) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("my_tab"))
+ assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
+ assert(create.partitioning.isEmpty)
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties == Map("test" -> "test"))
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.isEmpty)
+ assert(create.comment.isEmpty)
+ assert(!create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("create table - with location") {
+ val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'"
+
+ parsePlan(sql) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("my_tab"))
+ assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType))
+ assert(create.partitioning.isEmpty)
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties.isEmpty)
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.contains("/tmp/file"))
+ assert(create.comment.isEmpty)
+ assert(!create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("create table - byte length literal table name") {
+ val sql = "CREATE TABLE 1m.2g(a INT) USING parquet"
+
+ parsePlan(sql) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("1m", "2g"))
+ assert(create.tableSchema == new StructType().add("a", IntegerType))
+ assert(create.partitioning.isEmpty)
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties.isEmpty)
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.isEmpty)
+ assert(create.comment.isEmpty)
+ assert(!create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("Duplicate clauses - create table") {
+ def createTableHeader(duplicateClause: String): String = {
+ s"CREATE TABLE my_tab(a INT, b STRING) USING parquet $duplicateClause $duplicateClause"
+ }
+
+ intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"),
+ "Found duplicate clauses: TBLPROPERTIES")
+ intercept(createTableHeader("LOCATION '/tmp/file'"),
+ "Found duplicate clauses: LOCATION")
+ intercept(createTableHeader("COMMENT 'a table'"),
+ "Found duplicate clauses: COMMENT")
+ intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"),
+ "Found duplicate clauses: CLUSTERED BY")
+ intercept(createTableHeader("PARTITIONED BY (b)"),
+ "Found duplicate clauses: PARTITIONED BY")
+ }
+
+ test("support for other types in OPTIONS") {
+ val sql =
+ """
+ |CREATE TABLE table_name USING json
+ |OPTIONS (a 1, b 0.1, c TRUE)
+ """.stripMargin
+
+ parsePlan(sql) match {
+ case create: CreateTableStatement =>
+ assert(create.tableName == Seq("table_name"))
+ assert(create.tableSchema == new StructType)
+ assert(create.partitioning.isEmpty)
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties.isEmpty)
+ assert(create.provider == "json")
+ assert(create.options == Map("a" -> "1", "b" -> "0.1", "c" -> "true"))
+ assert(create.location.isEmpty)
+ assert(create.comment.isEmpty)
+ assert(!create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("Test CTAS against native tables") {
+ val s1 =
+ """
+ |CREATE TABLE IF NOT EXISTS mydb.page_view
+ |USING parquet
+ |COMMENT 'This is the staging page view table'
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src
+ """.stripMargin
+
+ val s2 =
+ """
+ |CREATE TABLE IF NOT EXISTS mydb.page_view
+ |USING parquet
+ |LOCATION '/user/external/page_view'
+ |COMMENT 'This is the staging page view table'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src
+ """.stripMargin
+
+ val s3 =
+ """
+ |CREATE TABLE IF NOT EXISTS mydb.page_view
+ |USING parquet
+ |COMMENT 'This is the staging page view table'
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src
+ """.stripMargin
+
+ checkParsing(s1)
+ checkParsing(s2)
+ checkParsing(s3)
+
+ def checkParsing(sql: String): Unit = {
+ parsePlan(sql) match {
+ case create: CreateTableAsSelectStatement =>
+ assert(create.tableName == Seq("mydb", "page_view"))
+ assert(create.partitioning.isEmpty)
+ assert(create.bucketSpec.isEmpty)
+ assert(create.properties == Map("p1" -> "v1", "p2" -> "v2"))
+ assert(create.provider == "parquet")
+ assert(create.options.isEmpty)
+ assert(create.location.contains("/user/external/page_view"))
+ assert(create.comment.contains("This is the staging page view table"))
+ assert(create.ifNotExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableAsSelectStatement].getClass.getName} " +
+ s"from query, got ${other.getClass.getName}: $sql")
+ }
+ }
+ }
+
+ test("drop table") {
+ parseCompare("DROP TABLE testcat.ns1.ns2.tbl",
+ DropTableStatement(Seq("testcat", "ns1", "ns2", "tbl"), ifExists = false, purge = false))
+ parseCompare(s"DROP TABLE db.tab",
+ DropTableStatement(Seq("db", "tab"), ifExists = false, purge = false))
+ parseCompare(s"DROP TABLE IF EXISTS db.tab",
+ DropTableStatement(Seq("db", "tab"), ifExists = true, purge = false))
+ parseCompare(s"DROP TABLE tab",
+ DropTableStatement(Seq("tab"), ifExists = false, purge = false))
+ parseCompare(s"DROP TABLE IF EXISTS tab",
+ DropTableStatement(Seq("tab"), ifExists = true, purge = false))
+ parseCompare(s"DROP TABLE tab PURGE",
+ DropTableStatement(Seq("tab"), ifExists = false, purge = true))
+ parseCompare(s"DROP TABLE IF EXISTS tab PURGE",
+ DropTableStatement(Seq("tab"), ifExists = true, purge = true))
+ }
+
+ test("drop view") {
+ parseCompare(s"DROP VIEW testcat.db.view",
+ DropViewStatement(Seq("testcat", "db", "view"), ifExists = false))
+ parseCompare(s"DROP VIEW db.view", DropViewStatement(Seq("db", "view"), ifExists = false))
+ parseCompare(s"DROP VIEW IF EXISTS db.view",
+ DropViewStatement(Seq("db", "view"), ifExists = true))
+ parseCompare(s"DROP VIEW view", DropViewStatement(Seq("view"), ifExists = false))
+ parseCompare(s"DROP VIEW IF EXISTS view", DropViewStatement(Seq("view"), ifExists = true))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
index ff0de0fb7c1f0..489b7f328f8fa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala
@@ -47,8 +47,8 @@ class TableIdentifierParserSuite extends SparkFunSuite {
"cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external",
"false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in",
"insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null",
- "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke",
- "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger",
+ "of", "order", "out", "outer", "partition", "percent", "procedure", "query", "range", "reads",
+ "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger",
"true", "truncate", "update", "user", "values", "with", "regexp", "rlike",
"bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float",
"int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
similarity index 96%
rename from sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
index d801f62b62323..4439a7bb3ae87 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
@@ -15,9 +15,9 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.arrow
+package org.apache.spark.sql.util
-import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.DateTimeUtils
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala
similarity index 53%
rename from sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala
index cfa69a86de1a7..0accb471cada3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala
@@ -15,31 +15,38 @@
* limitations under the License.
*/
-package org.apache.spark.sql.sources.v2
+package org.apache.spark.sql.util
+
+import java.util
import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
-/**
- * A simple test suite to verify `DataSourceOptions`.
- */
-class DataSourceOptionsSuite extends SparkFunSuite {
+class CaseInsensitiveStringMapSuite extends SparkFunSuite {
+
+ test("put and get") {
+ val options = CaseInsensitiveStringMap.empty()
+ intercept[UnsupportedOperationException] {
+ options.put("kEy", "valUE")
+ }
+ }
- test("key is case-insensitive") {
- val options = new DataSourceOptions(Map("foo" -> "bar").asJava)
- assert(options.get("foo").get() == "bar")
- assert(options.get("FoO").get() == "bar")
- assert(!options.get("abc").isPresent)
+ test("clear") {
+ val options = new CaseInsensitiveStringMap(Map("kEy" -> "valUE").asJava)
+ intercept[UnsupportedOperationException] {
+ options.clear()
+ }
}
- test("value is case-sensitive") {
- val options = new DataSourceOptions(Map("foo" -> "bAr").asJava)
- assert(options.get("foo").get == "bAr")
+ test("key and value set") {
+ val options = new CaseInsensitiveStringMap(Map("kEy" -> "valUE").asJava)
+ assert(options.keySet().asScala == Set("key"))
+ assert(options.values().asScala.toSeq == Seq("valUE"))
}
test("getInt") {
- val options = new DataSourceOptions(Map("numFOo" -> "1", "foo" -> "bar").asJava)
+ val options = new CaseInsensitiveStringMap(Map("numFOo" -> "1", "foo" -> "bar").asJava)
assert(options.getInt("numFOO", 10) == 1)
assert(options.getInt("numFOO2", 10) == 10)
@@ -49,17 +56,20 @@ class DataSourceOptionsSuite extends SparkFunSuite {
}
test("getBoolean") {
- val options = new DataSourceOptions(
+ val options = new CaseInsensitiveStringMap(
Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava)
assert(options.getBoolean("isFoo", false))
assert(!options.getBoolean("isFoo2", true))
assert(options.getBoolean("isBar", true))
assert(!options.getBoolean("isBar", false))
- assert(!options.getBoolean("FOO", true))
+
+ intercept[IllegalArgumentException] {
+ options.getBoolean("FOO", true)
+ }
}
test("getLong") {
- val options = new DataSourceOptions(Map("numFoo" -> "9223372036854775807",
+ val options = new CaseInsensitiveStringMap(Map("numFoo" -> "9223372036854775807",
"foo" -> "bar").asJava)
assert(options.getLong("numFOO", 0L) == 9223372036854775807L)
assert(options.getLong("numFoo2", -1L) == -1L)
@@ -70,7 +80,7 @@ class DataSourceOptionsSuite extends SparkFunSuite {
}
test("getDouble") {
- val options = new DataSourceOptions(Map("numFoo" -> "922337.1",
+ val options = new CaseInsensitiveStringMap(Map("numFoo" -> "922337.1",
"foo" -> "bar").asJava)
assert(options.getDouble("numFOO", 0d) == 922337.1d)
assert(options.getDouble("numFoo2", -1.02d) == -1.02d)
@@ -80,28 +90,19 @@ class DataSourceOptionsSuite extends SparkFunSuite {
}
}
- test("standard options") {
- val options = new DataSourceOptions(Map(
- DataSourceOptions.PATH_KEY -> "abc",
- DataSourceOptions.TABLE_KEY -> "tbl").asJava)
-
- assert(options.paths().toSeq == Seq("abc"))
- assert(options.tableName().get() == "tbl")
- assert(!options.databaseName().isPresent)
- }
-
- test("standard options with both singular path and multi-paths") {
- val options = new DataSourceOptions(Map(
- DataSourceOptions.PATH_KEY -> "abc",
- DataSourceOptions.PATHS_KEY -> """["c", "d"]""").asJava)
-
- assert(options.paths().toSeq == Seq("abc", "c", "d"))
- }
-
- test("standard options with only multi-paths") {
- val options = new DataSourceOptions(Map(
- DataSourceOptions.PATHS_KEY -> """["c", "d\"e"]""").asJava)
+ test("asCaseSensitiveMap") {
+ val originalMap = new util.HashMap[String, String] {
+ put("Foo", "Bar")
+ put("OFO", "ABR")
+ put("OoF", "bar")
+ }
- assert(options.paths().toSeq == Seq("c", "d\"e"))
+ val options = new CaseInsensitiveStringMap(originalMap)
+ val caseSensitiveMap = options.asCaseSensitiveMap
+ assert(caseSensitiveMap.equals(originalMap))
+ // The result of `asCaseSensitiveMap` is read-only.
+ intercept[UnsupportedOperationException] {
+ caseSensitiveMap.put("kEy", "valUE")
+ }
}
}
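For reference, a minimal Scala sketch of the behavior the suite above pins down; the option names here are illustrative, not part of the change:

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.util.CaseInsensitiveStringMap

    // Keys are matched (and exposed) lower-cased; values keep their original case.
    val options = new CaseInsensitiveStringMap(
      Map("Path" -> "/tmp/Data", "numPartitions" -> "4", "isStreaming" -> "false").asJava)

    assert(options.keySet().asScala == Set("path", "numpartitions", "isstreaming"))
    assert(options.getInt("NUMPARTITIONS", 1) == 4)   // typed getters look up case-insensitively
    assert(!options.getBoolean("ISSTREAMING", true))  // value "false" is parsed, default ignored
    assert(options.getBoolean("missing", true))       // default returned for an absent key

    // The map is read-only: mutators such as put/clear throw UnsupportedOperationException.
    try options.put("key", "value") catch { case _: UnsupportedOperationException => () }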
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 95e98c5444721..6f0db3632d7dd 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -112,10 +112,6 @@
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
- <dependency>
- <groupId>org.apache.arrow</groupId>
- <artifactId>arrow-vector</artifactId>
- </dependency>
<groupId>org.apache.xbean</groupId>
<artifactId>xbean-asm7-shaded</artifactId>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java
similarity index 68%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java
rename to sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java
index c44b8af2552f0..7c167dc012329 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java
@@ -18,12 +18,10 @@
package org.apache.spark.sql.execution.streaming;
/**
- * The shared interface between V1 streaming sources and V2 streaming readers.
+ * This class is an alias of {@link org.apache.spark.sql.sources.v2.reader.streaming.Offset}. It's
+ * internal and deprecated. New streaming data source implementations should use the data source v2
+ * API, which will be supported in the long term.
*
- * This is a temporary interface for compatibility during migration. It should not be implemented
- * directly, and will be removed in future versions.
+ * This class will be removed in a future release.
*/
-public interface BaseStreamingSource {
- /** Stop this source and free any resources it has allocated. */
- void stop();
-}
+public abstract class Offset extends org.apache.spark.sql.sources.v2.reader.streaming.Offset {}
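A minimal sketch of a source-side offset written against the aliased class; it assumes the only abstract member inherited from the v2 streaming Offset is its JSON representation:

    import org.apache.spark.sql.execution.streaming.Offset

    // Hypothetical offset tracking a single long position.
    case class SimpleLongOffset(position: Long) extends Offset {
      override def json: String = position.toString
    }

    assert(SimpleLongOffset(42L).json == "42")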
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index 4e4242fe8d9b9..fca7e36859126 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -26,7 +26,6 @@
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.sql.vectorized.ColumnarRow;
-import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -39,17 +38,10 @@
*/
public final class MutableColumnarRow extends InternalRow {
public int rowId;
- private final ColumnVector[] columns;
- private final WritableColumnVector[] writableColumns;
-
- public MutableColumnarRow(ColumnVector[] columns) {
- this.columns = columns;
- this.writableColumns = null;
- }
+ private final WritableColumnVector[] columns;
public MutableColumnarRow(WritableColumnVector[] writableColumns) {
this.columns = writableColumns;
- this.writableColumns = writableColumns;
}
@Override
@@ -228,54 +220,54 @@ public void update(int ordinal, Object value) {
@Override
public void setNullAt(int ordinal) {
- writableColumns[ordinal].putNull(rowId);
+ columns[ordinal].putNull(rowId);
}
@Override
public void setBoolean(int ordinal, boolean value) {
- writableColumns[ordinal].putNotNull(rowId);
- writableColumns[ordinal].putBoolean(rowId, value);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putBoolean(rowId, value);
}
@Override
public void setByte(int ordinal, byte value) {
- writableColumns[ordinal].putNotNull(rowId);
- writableColumns[ordinal].putByte(rowId, value);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putByte(rowId, value);
}
@Override
public void setShort(int ordinal, short value) {
- writableColumns[ordinal].putNotNull(rowId);
- writableColumns[ordinal].putShort(rowId, value);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putShort(rowId, value);
}
@Override
public void setInt(int ordinal, int value) {
- writableColumns[ordinal].putNotNull(rowId);
- writableColumns[ordinal].putInt(rowId, value);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putInt(rowId, value);
}
@Override
public void setLong(int ordinal, long value) {
- writableColumns[ordinal].putNotNull(rowId);
- writableColumns[ordinal].putLong(rowId, value);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putLong(rowId, value);
}
@Override
public void setFloat(int ordinal, float value) {
- writableColumns[ordinal].putNotNull(rowId);
- writableColumns[ordinal].putFloat(rowId, value);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putFloat(rowId, value);
}
@Override
public void setDouble(int ordinal, double value) {
- writableColumns[ordinal].putNotNull(rowId);
- writableColumns[ordinal].putDouble(rowId, value);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putDouble(rowId, value);
}
@Override
public void setDecimal(int ordinal, Decimal value, int precision) {
- writableColumns[ordinal].putNotNull(rowId);
- writableColumns[ordinal].putDecimal(rowId, value, precision);
+ columns[ordinal].putNotNull(rowId);
+ columns[ordinal].putDecimal(rowId, value, precision);
}
}
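A small sketch of the narrowed contract: the row now always wraps writable vectors, so every setter writes straight through the single columns array. Capacity and types below are arbitrary:

    import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector, WritableColumnVector}
    import org.apache.spark.sql.types.IntegerType

    val vectors: Array[WritableColumnVector] = Array(new OnHeapColumnVector(16, IntegerType))
    val row = new MutableColumnarRow(vectors)

    row.rowId = 0
    row.setInt(0, 42)              // putNotNull + putInt on the backing vector
    assert(row.getInt(0) == 42)

    row.rowId = 1
    row.setNullAt(0)
    assert(row.isNullAt(0))

    vectors.foreach(_.close())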
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
deleted file mode 100644
index 00af0bf1b172c..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java
+++ /dev/null
@@ -1,210 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2;
-
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Locale;
-import java.util.Map;
-import java.util.Optional;
-import java.util.stream.Stream;
-
-import com.fasterxml.jackson.databind.ObjectMapper;
-
-import org.apache.spark.annotation.Evolving;
-
-/**
- * An immutable string-to-string map in which keys are case-insensitive. This is used to represent
- * data source options.
- *
- * Each data source implementation can define its own options and teach its users how to set them.
- * Spark doesn't have any restrictions about what options a data source should or should not have.
- * Instead, Spark defines some standard options that data sources can optionally adopt. It's
- * possible that some options are very common and many data sources use them. However, different
- * data sources may define the common options (key and meaning) differently, which is quite
- * confusing to end users.
- *
- * The standard options defined by Spark:
- *
- *
- * path: A path string of the data files/directories, like
- *   path1, /absolute/file2, path3/*. The path can
- *   either be relative or absolute, points to either file or directory, and can contain
- *   wildcards. This option is commonly used by file-based data sources.
- *
- * paths: A JSON array style paths string of the data files/directories, like
- *   ["path1", "/absolute/file2"]. The format of each path is the same as the
- *   path option, plus it should follow JSON string literal format, e.g. quotes
- *   should be escaped, pa\"th means pa"th.
- *
- * table: A table name string representing the table name directly without any interpretation.
- *   For example, db.tbl means a table called db.tbl, not a table called tbl
- *   inside database db. `t*b.l` means a table called `t*b.l`, not t*b.l.
- *
- * database: A database name string representing the database name directly without any
- *   interpretation, which is very similar to the table name option.
- */
-@Evolving
-public class DataSourceOptions {
- private final Map<String, String> keyLowerCasedMap;
-
- private String toLowerCase(String key) {
- return key.toLowerCase(Locale.ROOT);
- }
-
- public static DataSourceOptions empty() {
- return new DataSourceOptions(new HashMap<>());
- }
-
- public DataSourceOptions(Map<String, String> originalMap) {
- keyLowerCasedMap = new HashMap<>(originalMap.size());
- for (Map.Entry<String, String> entry : originalMap.entrySet()) {
- keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue());
- }
- }
-
- public Map<String, String> asMap() {
- return new HashMap<>(keyLowerCasedMap);
- }
-
- /**
- * Returns the option value to which the specified key is mapped, case-insensitively.
- */
- public Optional<String> get(String key) {
- return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key)));
- }
-
- /**
- * Returns the boolean value to which the specified key is mapped,
- * or defaultValue if there is no mapping for the key. The key match is case-insensitive
- */
- public boolean getBoolean(String key, boolean defaultValue) {
- String lcaseKey = toLowerCase(key);
- return keyLowerCasedMap.containsKey(lcaseKey) ?
- Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
- }
-
- /**
- * Returns the integer value to which the specified key is mapped,
- * or defaultValue if there is no mapping for the key. The key match is case-insensitive
- */
- public int getInt(String key, int defaultValue) {
- String lcaseKey = toLowerCase(key);
- return keyLowerCasedMap.containsKey(lcaseKey) ?
- Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
- }
-
- /**
- * Returns the long value to which the specified key is mapped,
- * or defaultValue if there is no mapping for the key. The key match is case-insensitive
- */
- public long getLong(String key, long defaultValue) {
- String lcaseKey = toLowerCase(key);
- return keyLowerCasedMap.containsKey(lcaseKey) ?
- Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
- }
-
- /**
- * Returns the double value to which the specified key is mapped,
- * or defaultValue if there is no mapping for the key. The key match is case-insensitive
- */
- public double getDouble(String key, double defaultValue) {
- String lcaseKey = toLowerCase(key);
- return keyLowerCasedMap.containsKey(lcaseKey) ?
- Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue;
- }
-
- /**
- * The option key for singular path.
- */
- public static final String PATH_KEY = "path";
-
- /**
- * The option key for multiple paths.
- */
- public static final String PATHS_KEY = "paths";
-
- /**
- * The option key for table name.
- */
- public static final String TABLE_KEY = "table";
-
- /**
- * The option key for database name.
- */
- public static final String DATABASE_KEY = "database";
-
- /**
- * The option key for whether to check existence of files for a table.
- */
- public static final String CHECK_FILES_EXIST_KEY = "check_files_exist";
-
- /**
- * Returns all the paths specified by both the singular path option and the multiple
- * paths option.
- */
- public String[] paths() {
- String[] singularPath =
- get(PATH_KEY).map(s -> new String[]{s}).orElseGet(() -> new String[0]);
- Optional<String> pathsStr = get(PATHS_KEY);
- if (pathsStr.isPresent()) {
- ObjectMapper objectMapper = new ObjectMapper();
- try {
- String[] paths = objectMapper.readValue(pathsStr.get(), String[].class);
- return Stream.of(singularPath, paths).flatMap(Stream::of).toArray(String[]::new);
- } catch (IOException e) {
- return singularPath;
- }
- } else {
- return singularPath;
- }
- }
-
- /**
- * Returns the value of the table name option.
- */
- public Optional<String> tableName() {
- return get(TABLE_KEY);
- }
-
- /**
- * Returns the value of the database name option.
- */
- public Optional<String> databaseName() {
- return get(DATABASE_KEY);
- }
-
- public Boolean checkFilesExist() {
- Optional<String> result = get(CHECK_FILES_EXIST_KEY);
- return result.isPresent() && result.get().equals("true");
- }
-}
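With the class above gone, the standard keys become plain option strings ("path", "paths", "table", "database", "check_files_exist"). A brief sketch, using nothing beyond Jackson, of how multiple paths survive as a single JSON-array-valued "paths" option, mirroring what DataFrameReader now does inline further below:

    import com.fasterxml.jackson.databind.ObjectMapper

    val mapper = new ObjectMapper()

    // Writer side: encode the user-supplied paths as a JSON array string.
    val encoded = mapper.writeValueAsString(Array("path1", "/absolute/file2"))

    // Reader side: decode the option value back into individual paths.
    val decoded = mapper.readValue(encoded, classOf[Array[String]])
    assert(decoded.toSeq == Seq("path1", "/absolute/file2"))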
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java
deleted file mode 100644
index 8ac9c51750865..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2;
-
-import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.execution.streaming.BaseStreamingSink;
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport;
-import org.apache.spark.sql.streaming.OutputMode;
-import org.apache.spark.sql.types.StructType;
-
-/**
- * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to
- * provide data writing ability for structured streaming.
- *
- * This interface is used to create {@link StreamingWriteSupport} instances when end users run
- * {@code Dataset.writeStream.format(...).option(...).start()}.
- */
-@Evolving
-public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink {
-
- /**
- * Creates a {@link StreamingWriteSupport} instance to save the data to this data source, which is
- * called by Spark at the beginning of each streaming query.
- *
- * @param queryId A unique string for the writing query. It's possible that there are many
- * writing queries running at the same time, and the returned
- * {@link StreamingWriteSupport} can use this id to distinguish itself from others.
- * @param schema the schema of the data to be written.
- * @param mode the output mode which determines what successive epoch output means to this
- * sink, please refer to {@link OutputMode} for more details.
- * @param options the options for the returned data source writer, which is an immutable
- * case-insensitive string-to-string map.
- */
- StreamingWriteSupport createStreamingWriteSupport(
- String queryId,
- StructType schema,
- OutputMode mode,
- DataSourceOptions options);
-}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java
deleted file mode 100644
index 6c5a95d2a75b7..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2;
-
-import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.v2.reader.Scan;
-import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
-
-/**
- * An empty mix-in interface for {@link Table}, to indicate this table supports batch scan.
- *
- * If a {@link Table} implements this interface, the
- * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that
- * builds {@link Scan} with {@link Scan#toBatch()} implemented.
- *
- */
-@Evolving
-public interface SupportsBatchRead extends SupportsRead { }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java
deleted file mode 100644
index 08caadd5308e6..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.sources.v2;
-
-import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.v2.writer.WriteBuilder;
-
-/**
- * An empty mix-in interface for {@link Table}, to indicate this table supports batch write.
- *
- * If a {@link Table} implements this interface, the
- * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder}
- * with {@link WriteBuilder#buildForBatch()} implemented.
- *
- */
-@Evolving
-public interface SupportsBatchWrite extends SupportsWrite {}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
deleted file mode 100644
index 07546a54013ec..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java
+++ /dev/null
@@ -1,113 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.vectorized;
-
-import java.util.*;
-
-import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.execution.vectorized.MutableColumnarRow;
-
-/**
- * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this
- * batch so that Spark can access the data row by row. An instance of it is meant to be reused
- * during the entire data loading process.
- */
-@Evolving
-public final class ColumnarBatch {
- private int numRows;
- private final ColumnVector[] columns;
-
- // Staging row returned from `getRow`.
- private final MutableColumnarRow row;
-
- /**
- * Called to close all the columns in this batch. It is not valid to access the data after
- * calling this. This must be called at the end to clean up memory allocations.
- */
- public void close() {
- for (ColumnVector c: columns) {
- c.close();
- }
- }
-
- /**
- * Returns an iterator over the rows in this batch.
- */
- public Iterator<InternalRow> rowIterator() {
- final int maxRows = numRows;
- final MutableColumnarRow row = new MutableColumnarRow(columns);
- return new Iterator<InternalRow>() {
- int rowId = 0;
-
- @Override
- public boolean hasNext() {
- return rowId < maxRows;
- }
-
- @Override
- public InternalRow next() {
- if (rowId >= maxRows) {
- throw new NoSuchElementException();
- }
- row.rowId = rowId++;
- return row;
- }
-
- @Override
- public void remove() {
- throw new UnsupportedOperationException();
- }
- };
- }
-
- /**
- * Sets the number of rows in this batch.
- */
- public void setNumRows(int numRows) {
- this.numRows = numRows;
- }
-
- /**
- * Returns the number of columns that make up this batch.
- */
- public int numCols() { return columns.length; }
-
- /**
- * Returns the number of rows for read, including filtered rows.
- */
- public int numRows() { return numRows; }
-
- /**
- * Returns the column at `ordinal`.
- */
- public ColumnVector column(int ordinal) { return columns[ordinal]; }
-
- /**
- * Returns the row in this batch at `rowId`. Returned row is reused across calls.
- */
- public InternalRow getRow(int rowId) {
- assert(rowId >= 0 && rowId < numRows);
- row.rowId = rowId;
- return row;
- }
-
- public ColumnarBatch(ColumnVector[] columns) {
- this.columns = columns;
- this.row = new MutableColumnarRow(columns);
- }
-}
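The file above is removed from sql/core, but other files in this diff still import org.apache.spark.sql.vectorized.ColumnarBatch, so the class itself is relocated rather than dropped. A short sketch of its batch/row contract; note the single reused row view, so values must be consumed before advancing:

    import scala.collection.JavaConverters._

    import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
    import org.apache.spark.sql.types.IntegerType
    import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch}

    val col = new OnHeapColumnVector(2, IntegerType)
    col.putInt(0, 1)
    col.putInt(1, 2)

    val batch = new ColumnarBatch(Array[ColumnVector](col))
    batch.setNumRows(2)

    // rowIterator() hands back the same staging row for every element.
    val total = batch.rowIterator().asScala.map(_.getInt(0)).sum
    assert(total == 3)

    batch.close()   // closes the wrapped vectors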
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a380a06cb942b..0cf9957539e73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -37,9 +37,11 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.csv._
import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, FileTable}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2}
import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.unsafe.types.UTF8String
/**
@@ -176,7 +178,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
*/
def load(path: String): DataFrame = {
// force invocation of `load(...varargs...)`
- option(DataSourceOptions.PATH_KEY, path).load(Seq.empty: _*)
+ option("path", path).load(Seq.empty: _*)
}
/**
@@ -193,7 +195,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
val useV1Sources =
- sparkSession.sessionState.conf.userV1SourceReaderList.toLowerCase(Locale.ROOT).split(",")
+ sparkSession.sessionState.conf.useV1SourceReaderList.toLowerCase(Locale.ROOT).split(",")
val lookupCls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
val cls = lookupCls.newInstance() match {
case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) ||
@@ -205,21 +207,25 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
if (classOf[TableProvider].isAssignableFrom(cls)) {
val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
- ds = provider, conf = sparkSession.sessionState.conf)
- val pathsOption = {
+ source = provider, conf = sparkSession.sessionState.conf)
+ val pathsOption = if (paths.isEmpty) {
+ None
+ } else {
val objectMapper = new ObjectMapper()
- DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray)
+ Some("paths" -> objectMapper.writeValueAsString(paths.toArray))
}
- val checkFilesExistsOption = DataSourceOptions.CHECK_FILES_EXIST_KEY -> "true"
- val finalOptions = sessionOptions ++ extraOptions.toMap + pathsOption + checkFilesExistsOption
- val dsOptions = new DataSourceOptions(finalOptions.asJava)
+ // TODO SPARK-27113: remove this option.
+ val checkFilesExistsOpt = "check_files_exist" -> "true"
+ val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption + checkFilesExistsOpt
+ val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
val table = userSpecifiedSchema match {
case Some(schema) => provider.getTable(dsOptions, schema)
case _ => provider.getTable(dsOptions)
}
+ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
table match {
- case _: SupportsBatchRead =>
- Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, finalOptions))
+ case _: SupportsRead if table.supports(BATCH_READ) =>
+ Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, dsOptions))
case _ => loadV1Source(paths: _*)
}
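With SupportsBatchRead gone, batch support is a capability reported by the table rather than a marker interface. A minimal sketch of the dispatch used above, where `table` stands in for whatever TableProvider.getTable returned:

    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
    import org.apache.spark.sql.sources.v2.{SupportsRead, Table}
    import org.apache.spark.sql.sources.v2.TableCapability.BATCH_READ

    def supportsBatchRead(table: Table): Boolean = table match {
      case _: SupportsRead => table.supports(BATCH_READ)  // capability check via the implicits
      case _               => false
    }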
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 47fb548ecd43c..b87b3bd4f0761 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -25,15 +25,18 @@ import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan}
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression}
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.DDLUtils
-import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.v2._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode
+import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* Interface used to write a [[Dataset]] to external storage systems (e.g. file systems,
@@ -52,13 +55,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `SaveMode.Overwrite`: overwrite the existing data.
* `SaveMode.Append`: append the data.
* `SaveMode.Ignore`: ignore the operation (i.e. no-op).
- * `SaveMode.ErrorIfExists`: default option, throw an exception at runtime.
+ * `SaveMode.ErrorIfExists`: throw an exception at runtime.
*
+ *
+ * When writing to data source v1, the default option is `ErrorIfExists`. When writing to data
+ * source v2, the default option is `Append`.
*
* @since 1.4.0
*/
def mode(saveMode: SaveMode): DataFrameWriter[T] = {
- this.mode = saveMode
+ this.mode = Some(saveMode)
this
}
@@ -74,15 +80,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* @since 1.4.0
*/
def mode(saveMode: String): DataFrameWriter[T] = {
- this.mode = saveMode.toLowerCase(Locale.ROOT) match {
- case "overwrite" => SaveMode.Overwrite
- case "append" => SaveMode.Append
- case "ignore" => SaveMode.Ignore
- case "error" | "errorifexists" | "default" => SaveMode.ErrorIfExists
+ saveMode.toLowerCase(Locale.ROOT) match {
+ case "overwrite" => mode(SaveMode.Overwrite)
+ case "append" => mode(SaveMode.Append)
+ case "ignore" => mode(SaveMode.Ignore)
+ case "error" | "errorifexists" => mode(SaveMode.ErrorIfExists)
+ case "default" => this
case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " +
"Accepted save modes are 'overwrite', 'append', 'ignore', 'error', 'errorifexists'.")
}
- this
}
/**
@@ -244,7 +250,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val session = df.sparkSession
val useV1Sources =
- session.sessionState.conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
+ session.sessionState.conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
val lookupCls = DataSource.lookupDataSource(source, session.sessionState.conf)
val cls = lookupCls.newInstance() match {
case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) ||
@@ -259,36 +265,48 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
provider, session.sessionState.conf)
- val checkFilesExistsOption = DataSourceOptions.CHECK_FILES_EXIST_KEY -> "false"
+ // TODO SPARK-27113: remove this option.
+ val checkFilesExistsOption = "check_files_exist" -> "false"
val options = sessionOptions ++ extraOptions + checkFilesExistsOption
- val dsOptions = new DataSourceOptions(options.asJava)
+ val dsOptions = new CaseInsensitiveStringMap(options.asJava)
+
+ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
provider.getTable(dsOptions) match {
- case table: SupportsBatchWrite =>
- if (mode == SaveMode.Append) {
- val relation = DataSourceV2Relation.create(table, options)
+ // TODO (SPARK-27815): To not break existing tests, here we treat file source as a special
+ // case, and pass the save mode to file source directly. This hack should be removed.
+ case table: FileTable =>
+ val write = table.newWriteBuilder(dsOptions).asInstanceOf[FileWriteBuilder]
+ .mode(modeForDSV1) // should not change default mode for file source.
+ .withQueryId(UUID.randomUUID().toString)
+ .withInputDataSchema(df.logicalPlan.schema)
+ .buildForBatch()
+ // The returned `Write` can be null, which indicates that we can skip writing.
+ if (write != null) {
runCommand(df.sparkSession, "save") {
- AppendData.byName(relation, df.logicalPlan)
- }
- } else {
- val writeBuilder = table.newWriteBuilder(dsOptions)
- .withQueryId(UUID.randomUUID().toString)
- .withInputDataSchema(df.logicalPlan.schema)
- writeBuilder match {
- case s: SupportsSaveMode =>
- val write = s.mode(mode).buildForBatch()
- // It can only return null with `SupportsSaveMode`. We can clean it up after
- // removing `SupportsSaveMode`.
- if (write != null) {
- runCommand(df.sparkSession, "save") {
- WriteToDataSourceV2(write, df.logicalPlan)
- }
- }
-
- case _ => throw new AnalysisException(
- s"data source ${table.name} does not support SaveMode $mode")
+ WriteToDataSourceV2(write, df.logicalPlan)
}
}
+ case table: SupportsWrite if table.supports(BATCH_WRITE) =>
+ lazy val relation = DataSourceV2Relation.create(table, dsOptions)
+ modeForDSV2 match {
+ case SaveMode.Append =>
+ runCommand(df.sparkSession, "save") {
+ AppendData.byName(relation, df.logicalPlan)
+ }
+
+ case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) =>
+ // truncate the table
+ runCommand(df.sparkSession, "save") {
+ OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true))
+ }
+
+ case other =>
+ throw new AnalysisException(s"TableProvider implementation $source cannot be " +
+ s"written with $other mode, please use Append or Overwrite " +
+ "modes instead.")
+ }
+
// Streaming also uses the data source V2 API. So it may be that the data source implements
// v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
// as though it's a V1 source.
@@ -306,7 +324,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
sparkSession = df.sparkSession,
className = source,
partitionColumns = partitioningColumns.getOrElse(Nil),
- options = extraOptions.toMap).planForWriting(mode, df.logicalPlan)
+ options = extraOptions.toMap).planForWriting(modeForDSV1, df.logicalPlan)
}
}
@@ -355,7 +373,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
table = UnresolvedRelation(tableIdent),
partition = Map.empty[String, Option[String]],
query = df.logicalPlan,
- overwrite = mode == SaveMode.Overwrite,
+ overwrite = modeForDSV1 == SaveMode.Overwrite,
ifPartitionNotExists = false)
}
}
@@ -435,7 +453,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val tableIdentWithDB = tableIdent.copy(database = Some(db))
val tableName = tableIdentWithDB.unquotedString
- (tableExists, mode) match {
+ (tableExists, modeForDSV1) match {
case (true, SaveMode.Ignore) =>
// Do nothing
@@ -490,7 +508,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
partitionColumnNames = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec)
- runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan)))
+ runCommand(df.sparkSession, "saveAsTable")(
+ CreateTable(tableDesc, modeForDSV1, Some(df.logicalPlan)))
}
/**
@@ -696,13 +715,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd)
}
+ private def modeForDSV1 = mode.getOrElse(SaveMode.ErrorIfExists)
+
+ private def modeForDSV2 = mode.getOrElse(SaveMode.Append)
+
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName
- private var mode: SaveMode = SaveMode.ErrorIfExists
+ private var mode: Option[SaveMode] = None
private val extraOptions = new scala.collection.mutable.HashMap[String, String]
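A usage-level sketch of the new defaults, assuming `df` is an existing DataFrame and the v2 format name is hypothetical: with no explicit mode(), a v2 batch write resolves to Append (modeForDSV2), while the v1 path, and file sources per the FileTable special case above, keep ErrorIfExists (modeForDSV1):

    // v2 TableProvider reporting BATCH_WRITE: no mode() call, so the write appends.
    df.write
      .format("com.example.MyTableProvider")   // hypothetical v2 source
      .option("path", "/tmp/v2-output")
      .save()

    // An explicit mode still wins; Overwrite also requires the table to report
    // TRUNCATE or OVERWRITE_BY_FILTER, per the match above.
    df.write
      .mode("overwrite")
      .format("com.example.MyTableProvider")
      .option("path", "/tmp/v2-output")
      .save()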
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index c5d14dfffd9b2..ff5ca2ac1111a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -21,6 +21,7 @@ import java.io.Closeable
import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -31,6 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalog.Catalog
+import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Catalogs}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.encoders._
@@ -619,6 +621,12 @@ class SparkSession private(
*/
@transient lazy val catalog: Catalog = new CatalogImpl(self)
+ @transient private lazy val catalogs = new mutable.HashMap[String, CatalogPlugin]()
+
+ private[sql] def catalog(name: String): CatalogPlugin = synchronized {
+ catalogs.getOrElseUpdate(name, Catalogs.load(name, sessionState.conf))
+ }
+
/**
* Returns the specified table/view as a `DataFrame`.
*
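A sketch of how the new per-session catalog cache would be exercised. The spark.sql.catalog.<name> configuration key and the plugin class are assumptions here, not something this hunk shows:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[*]")
      // Assumed convention: Catalogs.load resolves the plugin class for "testcat" from this key.
      .config("spark.sql.catalog.testcat", "com.example.InMemoryTableCatalog")
      .getOrCreate()

    // Resolution then goes through the new helper, which loads the plugin once and caches it.
    // It is private[sql], so the call below is illustrative rather than user-facing code:
    //   spark.catalog("testcat")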
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
index c90b254a6d121..41cebc247a186 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand}
+import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -35,7 +35,7 @@ object HiveResult {
* `SparkSQLDriver` for CLI applications.
*/
def hiveResultString(executedPlan: SparkPlan): Seq[String] = executedPlan match {
- case ExecutedCommandExec(desc: DescribeTableCommand) =>
+ case ExecutedCommandExec(_: DescribeCommandBase) =>
// If it is a describe command for a Hive table, we want to have the output format
// be similar with Hive.
executedPlan.executeCollectPublic().map {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 8deb55b00a9d3..ac61661e83e32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -370,127 +370,72 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
}
/**
- * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal).
+ * Create a [[DescribeQueryCommand]] logical command.
*/
- type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean)
+ override def visitDescribeQuery(ctx: DescribeQueryContext): LogicalPlan = withOrigin(ctx) {
+ DescribeQueryCommand(visitQueryToDesc(ctx.queryToDesc()))
+ }
/**
- * Validate a create table statement and return the [[TableIdentifier]].
- */
- override def visitCreateTableHeader(
- ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
- val temporary = ctx.TEMPORARY != null
- val ifNotExists = ctx.EXISTS != null
- if (temporary && ifNotExists) {
- operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx)
+ * Converts a multi-part identifier to a TableIdentifier.
+ *
+ * If the multi-part identifier has too many parts, this will throw a ParseException.
+ */
+ def tableIdentifier(
+ multipart: Seq[String],
+ command: String,
+ ctx: ParserRuleContext): TableIdentifier = {
+ multipart match {
+ case Seq(tableName) =>
+ TableIdentifier(tableName)
+ case Seq(database, tableName) =>
+ TableIdentifier(tableName, Some(database))
+ case _ =>
+ operationNotAllowed(s"$command does not support multi-part identifiers", ctx)
}
- (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null)
}
/**
* Create a table, returning a [[CreateTable]] logical plan.
*
- * Expected format:
- * {{{
- * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name
- * USING table_provider
- * create_table_clauses
- * [[AS] select_statement];
+ * This is used to produce CreateTempViewUsing from CREATE TEMPORARY TABLE.
*
- * create_table_clauses (order insensitive):
- * [OPTIONS table_property_list]
- * [PARTITIONED BY (col_name, col_name, ...)]
- * [CLUSTERED BY (col_name, col_name, ...)
- * [SORTED BY (col_name [ASC|DESC], ...)]
- * INTO num_buckets BUCKETS
- * ]
- * [LOCATION path]
- * [COMMENT table_comment]
- * [TBLPROPERTIES (property_name=property_value, ...)]
- * }}}
+ * TODO: Remove this. It is used because CreateTempViewUsing is not a Catalyst plan.
+ * Either move CreateTempViewUsing into catalyst as a parsed logical plan, or remove it because
+ * it is deprecated.
*/
override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) {
- val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
- if (external) {
- operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx)
- }
+ val (ident, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
- checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
- checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx)
- checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx)
- checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx)
- checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx)
- checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
-
- val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
- val provider = ctx.tableProvider.qualifiedName.getText
- val schema = Option(ctx.colTypeList()).map(createSchema)
- val partitionColumnNames =
- Option(ctx.partitionColumnNames)
- .map(visitIdentifierList(_).toArray)
- .getOrElse(Array.empty[String])
- val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty)
- val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec)
-
- val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec)
- val storage = DataSource.buildStorageFormatFromOptions(options)
-
- if (location.isDefined && storage.locationUri.isDefined) {
- throw new ParseException(
- "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " +
- "you can only specify one of them.", ctx)
- }
- val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI))
-
- val tableType = if (customLocation.isDefined) {
- CatalogTableType.EXTERNAL
+ if (!temp || ctx.query != null) {
+ super.visitCreateTable(ctx)
} else {
- CatalogTableType.MANAGED
- }
-
- val tableDesc = CatalogTable(
- identifier = table,
- tableType = tableType,
- storage = storage.copy(locationUri = customLocation),
- schema = schema.getOrElse(new StructType),
- provider = Some(provider),
- partitionColumnNames = partitionColumnNames,
- bucketSpec = bucketSpec,
- properties = properties,
- comment = Option(ctx.comment).map(string))
-
- // Determine the storage mode.
- val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists
-
- if (ctx.query != null) {
- // Get the backing query.
- val query = plan(ctx.query)
-
- if (temp) {
- operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx)
+ if (external) {
+ operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx)
}
- // Don't allow explicit specification of schema for CTAS
- if (schema.nonEmpty) {
- operationNotAllowed(
- "Schema may not be specified in a Create Table As Select (CTAS) statement",
- ctx)
- }
- CreateTable(tableDesc, mode, Some(query))
- } else {
- if (temp) {
- if (ifNotExists) {
- operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx)
- }
+ checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx)
+ checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx)
+ checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx)
+ checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx)
+ checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx)
+ checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx)
- logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " +
- "CREATE TEMPORARY VIEW ... USING ... instead")
+ if (ifNotExists) {
// Unlike CREATE TEMPORARY VIEW USING, CREATE TEMPORARY TABLE USING does not support
// IF NOT EXISTS. Users are not allowed to replace the existing temp table.
- CreateTempViewUsing(table, schema, replace = false, global = false, provider, options)
- } else {
- CreateTable(tableDesc, mode, None)
+ operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx)
}
+
+ val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty)
+ val provider = ctx.tableProvider.qualifiedName.getText
+ val schema = Option(ctx.colTypeList()).map(createSchema)
+
+ logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " +
+ "CREATE TEMPORARY VIEW ... USING ... instead")
+
+ val table = tableIdentifier(ident, "CREATE TEMPORARY VIEW", ctx)
+ CreateTempViewUsing(table, schema, replace = false, global = false, provider, options)
}
}
@@ -555,77 +500,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
"MSCK REPAIR TABLE")
}
- /**
- * Convert a table property list into a key-value map.
- * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]].
- */
- override def visitTablePropertyList(
- ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) {
- val properties = ctx.tableProperty.asScala.map { property =>
- val key = visitTablePropertyKey(property.key)
- val value = visitTablePropertyValue(property.value)
- key -> value
- }
- // Check for duplicate property names.
- checkDuplicateKeys(properties, ctx)
- properties.toMap
- }
-
- /**
- * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified.
- */
- private def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = {
- val props = visitTablePropertyList(ctx)
- val badKeys = props.collect { case (key, null) => key }
- if (badKeys.nonEmpty) {
- operationNotAllowed(
- s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
- }
- props
- }
-
- /**
- * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified.
- */
- private def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = {
- val props = visitTablePropertyList(ctx)
- val badKeys = props.filter { case (_, v) => v != null }.keys
- if (badKeys.nonEmpty) {
- operationNotAllowed(
- s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
- }
- props.keys.toSeq
- }
-
- /**
- * A table property key can either be String or a collection of dot separated elements. This
- * function extracts the property key based on whether its a string literal or a table property
- * identifier.
- */
- override def visitTablePropertyKey(key: TablePropertyKeyContext): String = {
- if (key.STRING != null) {
- string(key.STRING)
- } else {
- key.getText
- }
- }
-
- /**
- * A table property value can be String, Integer, Boolean or Decimal. This function extracts
- * the property value based on whether its a string, integer, boolean or decimal literal.
- */
- override def visitTablePropertyValue(value: TablePropertyValueContext): String = {
- if (value == null) {
- null
- } else if (value.STRING != null) {
- string(value.STRING)
- } else if (value.booleanValue != null) {
- value.getText.toLowerCase(Locale.ROOT)
- } else {
- value.getText
- }
- }
-
/**
* Create a [[CreateDatabaseCommand]] command.
*
@@ -772,17 +646,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
ctx.TEMPORARY != null)
}
- /**
- * Create a [[DropTableCommand]] command.
- */
- override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) {
- DropTableCommand(
- visitTableIdentifier(ctx.tableIdentifier),
- ctx.EXISTS != null,
- ctx.VIEW != null,
- ctx.PURGE != null)
- }
-
/**
* Create a [[AlterTableRenameCommand]] command.
*
@@ -999,34 +862,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
newColumn = visitColType(ctx.colType))
}
- /**
- * Create location string.
- */
- override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) {
- string(ctx.STRING)
- }
-
- /**
- * Create a [[BucketSpec]].
- */
- override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) {
- BucketSpec(
- ctx.INTEGER_VALUE.getText.toInt,
- visitIdentifierList(ctx.identifierList),
- Option(ctx.orderedIdentifierList)
- .toSeq
- .flatMap(_.orderedIdentifier.asScala)
- .map { orderedIdCtx =>
- Option(orderedIdCtx.ordering).map(_.getText).foreach { dir =>
- if (dir.toLowerCase(Locale.ROOT) != "asc") {
- operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx)
- }
- }
-
- orderedIdCtx.identifier.getText
- })
- }
-
/**
* Convert a nested constants list into a sequence of string sequences.
*/
@@ -1122,7 +957,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
* }}}
*/
override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) {
- val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+ val (ident, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
// TODO: implement temporary tables
if (temp) {
throw new ParseException(
@@ -1180,6 +1015,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
CatalogTableType.MANAGED
}
+ val name = tableIdentifier(ident, "CREATE TABLE ... STORED AS ...", ctx)
+
// TODO support the sql text - have a proper location for this!
val tableDesc = CatalogTable(
identifier = name,
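At the SQL level the deprecated form still parses but is now rewritten to CreateTempViewUsing (emitting the warning above); the second statement is the recommended spelling. This assumes an active SparkSession `spark` and an existing Parquet path:

    spark.sql("CREATE TEMPORARY TABLE old_style USING parquet OPTIONS (path '/tmp/people')")
    spark.sql("CREATE TEMPORARY VIEW new_style USING parquet OPTIONS (path '/tmp/people')")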
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index edfa70403ad15..e72ddf13f1668 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2
+import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery}
import org.apache.spark.sql.types.StructType
@@ -557,9 +557,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: RunnableCommand => ExecutedCommandExec(r) :: Nil
case MemoryPlan(sink, output) =>
- val encoder = RowEncoder(sink.schema)
- LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil
- case MemoryPlanV2(sink, output) =>
val encoder = RowEncoder(StructType.fromAttributes(output))
LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 2bf6a58b55658..4b692aaeb1e63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.util.{ByteBufferOutputStream, Utils}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 8dd484af6e908..6147d6fefd52a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -25,6 +25,7 @@ import org.apache.arrow.vector.complex._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
object ArrowWriter {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index d24e66e583857..8b70e336c14bb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -29,12 +29,12 @@ import org.apache.hadoop.fs.{FileContext, FsConstants, Path}
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.CatalogTableType._
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.Histogram
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier}
import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils}
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -494,6 +494,34 @@ case class TruncateTableCommand(
}
}
+abstract class DescribeCommandBase extends RunnableCommand {
+ override val output: Seq[Attribute] = Seq(
+ // Column names are based on Hive.
+ AttributeReference("col_name", StringType, nullable = false,
+ new MetadataBuilder().putString("comment", "name of the column").build())(),
+ AttributeReference("data_type", StringType, nullable = false,
+ new MetadataBuilder().putString("comment", "data type of the column").build())(),
+ AttributeReference("comment", StringType, nullable = true,
+ new MetadataBuilder().putString("comment", "comment of the column").build())()
+ )
+
+ protected def describeSchema(
+ schema: StructType,
+ buffer: ArrayBuffer[Row],
+ header: Boolean): Unit = {
+ if (header) {
+ append(buffer, s"# ${output.head.name}", output(1).name, output(2).name)
+ }
+ schema.foreach { column =>
+ append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull)
+ }
+ }
+
+ protected def append(
+ buffer: ArrayBuffer[Row], column: String, dataType: String, comment: String): Unit = {
+ buffer += Row(column, dataType, comment)
+ }
+}
/**
* Command that looks like
* {{{
@@ -504,17 +532,7 @@ case class DescribeTableCommand(
table: TableIdentifier,
partitionSpec: TablePartitionSpec,
isExtended: Boolean)
- extends RunnableCommand {
-
- override val output: Seq[Attribute] = Seq(
- // Column names are based on Hive.
- AttributeReference("col_name", StringType, nullable = false,
- new MetadataBuilder().putString("comment", "name of the column").build())(),
- AttributeReference("data_type", StringType, nullable = false,
- new MetadataBuilder().putString("comment", "data type of the column").build())(),
- AttributeReference("comment", StringType, nullable = true,
- new MetadataBuilder().putString("comment", "comment of the column").build())()
- )
+ extends DescribeCommandBase {
override def run(sparkSession: SparkSession): Seq[Row] = {
val result = new ArrayBuffer[Row]
@@ -603,22 +621,31 @@ case class DescribeTableCommand(
}
table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, ""))
}
+}
- private def describeSchema(
- schema: StructType,
- buffer: ArrayBuffer[Row],
- header: Boolean): Unit = {
- if (header) {
- append(buffer, s"# ${output.head.name}", output(1).name, output(2).name)
- }
- schema.foreach { column =>
- append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull)
- }
- }
+/**
+ * Command that looks like
+ * {{{
+ * DESCRIBE [QUERY] statement
+ * }}}
+ *
+ * Parameter 'statement' can be one of the following types:
+ * 1. SELECT statements
+ * 2. SELECT statements inside set operators (UNION, INTERSECT, etc.)
+ * 3. VALUES statements
+ * 4. TABLE statements, e.g. TABLE table_name
+ * 5. Statements of the form 'FROM table SELECT *'
+ *
+ * TODO: support CTEs.
+ */
+case class DescribeQueryCommand(query: LogicalPlan)
+ extends DescribeCommandBase {
- private def append(
- buffer: ArrayBuffer[Row], column: String, dataType: String, comment: String): Unit = {
- buffer += Row(column, dataType, comment)
+ override def run(sparkSession: SparkSession): Seq[Row] = {
+ val result = new ArrayBuffer[Row]
+ val queryExecution = sparkSession.sessionState.executePlan(query)
+ describeSchema(queryExecution.analyzed.schema, result, header = false)
+ result
}
}
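For reference, once the parser rule for DESCRIBE QUERY is wired up, end-user usage would look roughly like the sketch below. The SQL text and column values are assumptions based on the doc comment and the (col_name, data_type, comment) output schema above, not output from a real run.

    import org.apache.spark.sql.SparkSession

    object DescribeQueryExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("describe-query").getOrCreate()

        // DescribeQueryCommand analyzes the query and emits one row per output column,
        // using the (col_name, data_type, comment) schema defined in DescribeCommandBase.
        spark.sql("DESCRIBE QUERY SELECT 1 AS id, 'a' AS name").show()

        spark.stop()
      }
    }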
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index b0548bc21156e..622ad3b559ebd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -348,7 +348,8 @@ case class DataSource(
case (format: FileFormat, _)
if FileStreamSink.hasMetadata(
caseInsensitiveOptions.get("path").toSeq ++ paths,
- sparkSession.sessionState.newHadoopConf()) =>
+ sparkSession.sessionState.newHadoopConf(),
+ sparkSession.sessionState.conf) =>
val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head)
val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema)
val dataSchema = userSpecifiedSchema.orElse {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
new file mode 100644
index 0000000000000..19881f69f158c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.util.Locale
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog}
+import org.apache.spark.sql.catalog.v2.expressions.Transform
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.CastSupport
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils}
+import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.command.DropTableCommand
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.TableProvider
+import org.apache.spark.sql.types.StructType
+
+case class DataSourceResolution(
+ conf: SQLConf,
+ findCatalog: String => CatalogPlugin)
+ extends Rule[LogicalPlan] with CastSupport with LookupCatalog {
+
+ import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+
+ override protected def lookupCatalog(name: String): CatalogPlugin = findCatalog(name)
+
+ def defaultCatalog: Option[CatalogPlugin] = conf.defaultV2Catalog.map(findCatalog)
+
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case CreateTableStatement(
+ AsTableIdentifier(table), schema, partitionCols, bucketSpec, properties,
+ V1WriteProvider(provider), options, location, comment, ifNotExists) =>
+
+ val tableDesc = buildCatalogTable(table, schema, partitionCols, bucketSpec, properties,
+ provider, options, location, comment, ifNotExists)
+ val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists
+
+ CreateTable(tableDesc, mode, None)
+
+ case create: CreateTableStatement =>
+ // the provider was not a v1 source, convert to a v2 plan
+ val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName
+ val catalog = maybeCatalog.orElse(defaultCatalog)
+ .getOrElse(throw new AnalysisException(
+ s"No catalog specified for table ${identifier.quoted} and no default catalog is set"))
+ .asTableCatalog
+ convertCreateTable(catalog, identifier, create)
+
+ case CreateTableAsSelectStatement(
+ AsTableIdentifier(table), query, partitionCols, bucketSpec, properties,
+ V1WriteProvider(provider), options, location, comment, ifNotExists) =>
+
+ val tableDesc = buildCatalogTable(table, new StructType, partitionCols, bucketSpec,
+ properties, provider, options, location, comment, ifNotExists)
+ val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists
+
+ CreateTable(tableDesc, mode, Some(query))
+
+ case create: CreateTableAsSelectStatement =>
+ // the provider was not a v1 source, convert to a v2 plan
+ val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName
+ val catalog = maybeCatalog.orElse(defaultCatalog)
+ .getOrElse(throw new AnalysisException(
+ s"No catalog specified for table ${identifier.quoted} and no default catalog is set"))
+ .asTableCatalog
+ convertCTAS(catalog, identifier, create)
+
+ case DropTableStatement(CatalogObjectIdentifier(Some(catalog), ident), ifExists, _) =>
+ DropTable(catalog.asTableCatalog, ident, ifExists)
+
+ case DropTableStatement(AsTableIdentifier(tableName), ifExists, purge) =>
+ DropTableCommand(tableName, ifExists, isView = false, purge)
+
+ case DropViewStatement(CatalogObjectIdentifier(Some(catalog), ident), _) =>
+ throw new AnalysisException(
+ s"Can not specify catalog `${catalog.name}` for view $ident " +
+ s"because view support in catalog has not been implemented yet")
+
+ case DropViewStatement(AsTableIdentifier(tableName), ifExists) =>
+ DropTableCommand(tableName, ifExists, isView = true, purge = false)
+ }
+
+ object V1WriteProvider {
+ private val v1WriteOverrideSet =
+ conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",").toSet
+
+ def unapply(provider: String): Option[String] = {
+ if (v1WriteOverrideSet.contains(provider.toLowerCase(Locale.ROOT))) {
+ Some(provider)
+ } else {
+ lazy val providerClass = DataSource.lookupDataSource(provider, conf)
+ provider match {
+ case _ if classOf[TableProvider].isAssignableFrom(providerClass) =>
+ None
+ case _ =>
+ Some(provider)
+ }
+ }
+ }
+ }
+
+ private def buildCatalogTable(
+ table: TableIdentifier,
+ schema: StructType,
+ partitioning: Seq[Transform],
+ bucketSpec: Option[BucketSpec],
+ properties: Map[String, String],
+ provider: String,
+ options: Map[String, String],
+ location: Option[String],
+ comment: Option[String],
+ ifNotExists: Boolean): CatalogTable = {
+
+ val storage = DataSource.buildStorageFormatFromOptions(options)
+ if (location.isDefined && storage.locationUri.isDefined) {
+ throw new AnalysisException(
+ "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " +
+ "you can only specify one of them.")
+ }
+ val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI))
+
+ val tableType = if (customLocation.isDefined) {
+ CatalogTableType.EXTERNAL
+ } else {
+ CatalogTableType.MANAGED
+ }
+
+ CatalogTable(
+ identifier = table,
+ tableType = tableType,
+ storage = storage.copy(locationUri = customLocation),
+ schema = schema,
+ provider = Some(provider),
+ partitionColumnNames = partitioning.asPartitionColumns,
+ bucketSpec = bucketSpec,
+ properties = properties,
+ comment = comment)
+ }
+
+ private def convertCTAS(
+ catalog: TableCatalog,
+ identifier: Identifier,
+ ctas: CreateTableAsSelectStatement): CreateTableAsSelect = {
+ // convert the bucket spec and add it as a transform
+ val partitioning = ctas.partitioning ++ ctas.bucketSpec.map(_.asTransform)
+ val properties = convertTableProperties(
+ ctas.properties, ctas.options, ctas.location, ctas.comment, ctas.provider)
+
+ CreateTableAsSelect(
+ catalog,
+ identifier,
+ partitioning,
+ ctas.asSelect,
+ properties,
+ writeOptions = ctas.options.filterKeys(_ != "path"),
+ ignoreIfExists = ctas.ifNotExists)
+ }
+
+ private def convertCreateTable(
+ catalog: TableCatalog,
+ identifier: Identifier,
+ create: CreateTableStatement): CreateV2Table = {
+ // convert the bucket spec and add it as a transform
+ val partitioning = create.partitioning ++ create.bucketSpec.map(_.asTransform)
+ val properties = convertTableProperties(
+ create.properties, create.options, create.location, create.comment, create.provider)
+
+ CreateV2Table(
+ catalog,
+ identifier,
+ create.tableSchema,
+ partitioning,
+ properties,
+ ignoreIfExists = create.ifNotExists)
+ }
+
+ private def convertTableProperties(
+ properties: Map[String, String],
+ options: Map[String, String],
+ location: Option[String],
+ comment: Option[String],
+ provider: String): Map[String, String] = {
+ if (options.contains("path") && location.isDefined) {
+ throw new AnalysisException(
+ "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " +
+ "you can only specify one of them.")
+ }
+
+ if ((options.contains("comment") || properties.contains("comment"))
+ && comment.isDefined) {
+ throw new AnalysisException(
+ "COMMENT and option/property 'comment' are both used to set the table comment, you can " +
+ "only specify one of them.")
+ }
+
+ if (options.contains("provider") || properties.contains("provider")) {
+ throw new AnalysisException(
+ "USING and option/property 'provider' are both used to set the provider implementation, " +
+ "you can only specify one of them.")
+ }
+
+ val filteredOptions = options.filterKeys(_ != "path")
+
+ // create table properties from TBLPROPERTIES and OPTIONS clauses
+ val tableProperties = new mutable.HashMap[String, String]()
+ tableProperties ++= properties
+ tableProperties ++= filteredOptions
+
+ // convert USING, LOCATION, and COMMENT clauses to table properties
+ tableProperties += ("provider" -> provider)
+ comment.map(text => tableProperties += ("comment" -> text))
+ location.orElse(options.get("path")).map(loc => tableProperties += ("location" -> loc))
+
+ tableProperties.toMap
+ }
+}
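As a rough, standalone sketch (plain Scala maps, not the Catalyst/catalog types used above), the property-merging rules in convertTableProperties behave roughly as follows; all names and values here are illustrative.

    // Standalone sketch of the merge rules in convertTableProperties; not the actual Spark code.
    object TablePropertiesSketch {
      def merge(
          properties: Map[String, String],
          options: Map[String, String],
          location: Option[String],
          comment: Option[String],
          provider: String): Map[String, String] = {
        require(!(options.contains("path") && location.isDefined),
          "LOCATION and the 'path' option are mutually exclusive")
        require(!((options.contains("comment") || properties.contains("comment")) && comment.isDefined),
          "COMMENT and a 'comment' option/property are mutually exclusive")
        require(!(options.contains("provider") || properties.contains("provider")),
          "USING and a 'provider' option/property are mutually exclusive")

        // OPTIONS (minus 'path') and TBLPROPERTIES become plain table properties, then the
        // USING, COMMENT and LOCATION clauses are folded in as reserved keys.
        properties ++ options.filterKeys(_ != "path") ++ Map("provider" -> provider) ++
          comment.map("comment" -> _) ++
          location.orElse(options.get("path")).map("location" -> _)
      }

      def main(args: Array[String]): Unit = {
        println(merge(
          properties = Map("owner" -> "etl"),
          options = Map("path" -> "/tmp/t", "compression" -> "zstd"),
          location = None,
          comment = Some("demo table"),
          provider = "parquet"))
        // e.g. Map(owner -> etl, compression -> zstd, provider -> parquet,
        //          comment -> demo table, location -> /tmp/t), ordering may vary
      }
    }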
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index b5cf8c9515bfb..b73dc30d6f23c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -426,6 +426,22 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
}
object DataSourceStrategy {
+ /**
+   * The attribute names in a predicate can differ from the names in the schema when the
+   * analysis is case-insensitive. Rewrite them to match the schema exactly, so that downstream
+   * code no longer needs to worry about case sensitivity.
+ */
+ protected[sql] def normalizeFilters(
+ filters: Seq[Expression],
+ attributes: Seq[AttributeReference]): Seq[Expression] = {
+ filters.filterNot(SubqueryExpression.hasSubquery).map { e =>
+ e transform {
+ case a: AttributeReference =>
+ a.withName(attributes.find(_.semanticEquals(a)).get.name)
+ }
+ }
+ }
+
/**
* Tries to translate a Catalyst [[Expression]] into data source [[Filter]].
*
@@ -513,6 +529,12 @@ object DataSourceStrategy {
case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) =>
Some(sources.StringContains(a.name, v.toString))
+ case expressions.Literal(true, BooleanType) =>
+ Some(sources.AlwaysTrue)
+
+ case expressions.Literal(false, BooleanType) =>
+ Some(sources.AlwaysFalse)
+
case _ => None
}
}
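A simplified, self-contained illustration of what normalizeFilters achieves, using toy types rather than Catalyst expressions: predicate attribute names are rewritten to the exact casing used by the relation output, so later matching can be purely case-sensitive.

    // Toy model of DataSourceStrategy.normalizeFilters; Attr/Expr are illustrative stand-ins.
    object NormalizeFiltersSketch {
      final case class Attr(name: String)
      sealed trait Expr
      final case class EqualTo(attr: Attr, value: Any) extends Expr

      def normalize(filters: Seq[Expr], output: Seq[Attr]): Seq[Expr] = filters.map {
        case EqualTo(a, v) =>
          // "semanticEquals" is approximated here by a case-insensitive name match.
          EqualTo(output.find(_.name.equalsIgnoreCase(a.name)).getOrElse(a), v)
      }

      def main(args: Array[String]): Unit = {
        val output = Seq(Attr("userId"), Attr("eventTime"))
        println(normalize(Seq(EqualTo(Attr("USERID"), 42)), output))
        // List(EqualTo(Attr(userId),42))
      }
    }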
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala
index 254c09001f7ec..7c72495548e3a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -33,10 +35,15 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable
*/
class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case i @ InsertIntoTable(d @DataSourceV2Relation(table: OrcTable, _, _), _, _, _, _) =>
+ case i @ InsertIntoTable(d @ DataSourceV2Relation(table: OrcTable, _, _), _, _, _, _) =>
val v1FileFormat = new OrcFileFormat
- val relation = HadoopFsRelation(table.getFileIndex, table.getFileIndex.partitionSchema,
- table.schema(), None, v1FileFormat, d.options)(sparkSession)
+ val relation = HadoopFsRelation(
+ table.fileIndex,
+ table.fileIndex.partitionSchema,
+ table.schema(),
+ None,
+ v1FileFormat,
+ d.options.asScala.toMap)(sparkSession)
i.copy(table = LogicalRelation(relation))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 62ab5c80d47cf..970cbda6355e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -147,15 +147,7 @@ object FileSourceStrategy extends Strategy with Logging {
// - filters that need to be evaluated again after the scan
val filterSet = ExpressionSet(filters)
- // The attribute name of predicate could be different than the one in schema in case of
- // case insensitive, we should change them to match the one in schema, so we do not need to
- // worry about case sensitivity anymore.
- val normalizedFilters = filters.filterNot(SubqueryExpression.hasSubquery).map { e =>
- e transform {
- case a: AttributeReference =>
- a.withName(l.output.find(_.semanticEquals(a)).get.name)
- }
- }
+ val normalizedFilters = DataSourceStrategy.normalizeFilters(filters, l.output)
val partitionColumns =
l.resolve(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
index 452ebbbeb99c8..2d90fd594fa7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala
@@ -17,43 +17,42 @@
package org.apache.spark.sql.execution.datasources.noop
-import org.apache.spark.sql.SaveMode
+import java.util
+
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* This is no-op datasource. It does not do anything besides consuming its input.
* This can be useful for benchmarking or to cache data without any additional overhead.
*/
-class NoopDataSource
- extends DataSourceV2
- with TableProvider
- with DataSourceRegister
- with StreamingWriteSupportProvider {
-
+class NoopDataSource extends TableProvider with DataSourceRegister {
override def shortName(): String = "noop"
- override def getTable(options: DataSourceOptions): Table = NoopTable
- override def createStreamingWriteSupport(
- queryId: String,
- schema: StructType,
- mode: OutputMode,
- options: DataSourceOptions): StreamingWriteSupport = NoopStreamingWriteSupport
+ override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable
}
-private[noop] object NoopTable extends Table with SupportsBatchWrite {
- override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = NoopWriteBuilder
+private[noop] object NoopTable extends Table with SupportsWrite {
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = NoopWriteBuilder
override def name(): String = "noop-table"
override def schema(): StructType = new StructType()
+ override def capabilities(): util.Set[TableCapability] = Set(
+ TableCapability.BATCH_WRITE,
+ TableCapability.TRUNCATE,
+ TableCapability.ACCEPT_ANY_SCHEMA,
+ TableCapability.STREAMING_WRITE).asJava
}
-private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsSaveMode {
+private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate {
+ override def truncate(): WriteBuilder = this
override def buildForBatch(): BatchWrite = NoopBatchWrite
- override def mode(mode: SaveMode): WriteBuilder = this
+ override def buildForStreaming(): StreamingWrite = NoopStreamingWrite
}
private[noop] object NoopBatchWrite extends BatchWrite {
@@ -72,7 +71,7 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] {
override def abort(): Unit = {}
}
-private[noop] object NoopStreamingWriteSupport extends StreamingWriteSupport {
+private[noop] object NoopStreamingWrite extends StreamingWrite {
override def createStreamingWriterFactory(): StreamingDataWriterFactory =
NoopStreamingDataWriterFactory
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
@@ -85,4 +84,3 @@ private[noop] object NoopStreamingDataWriterFactory extends StreamingDataWriterF
taskId: Long,
epochId: Long): DataWriter[InternalRow] = NoopWriter
}
-
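The reworked NoopTable advertises BATCH_WRITE, TRUNCATE, ACCEPT_ANY_SCHEMA and STREAMING_WRITE, so the public usage pattern is unchanged; a minimal batch example (the local master and row count are arbitrary):

    import org.apache.spark.sql.SparkSession

    object NoopWriteExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("noop-write").getOrCreate()

        // Every row is consumed and discarded, which makes the noop sink handy for
        // benchmarking the read/planning side of a query without any write overhead.
        spark.range(0, 1000000).write.format("noop").mode("overwrite").save()

        spark.stop()
      }
    }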
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala
new file mode 100644
index 0000000000000..f35758bf08c67
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog}
+import org.apache.spark.sql.catalog.v2.expressions.Transform
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.LeafExecNode
+import org.apache.spark.sql.types.StructType
+
+case class CreateTableExec(
+ catalog: TableCatalog,
+ identifier: Identifier,
+ tableSchema: StructType,
+ partitioning: Seq[Transform],
+ tableProperties: Map[String, String],
+ ignoreIfExists: Boolean) extends LeafExecNode {
+ import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ if (!catalog.tableExists(identifier)) {
+ try {
+ catalog.createTable(identifier, tableSchema, partitioning.toArray, tableProperties.asJava)
+ } catch {
+ case _: TableAlreadyExistsException if ignoreIfExists =>
+ logWarning(s"Table ${identifier.quoted} was created concurrently. Ignoring.")
+ }
+ } else if (!ignoreIfExists) {
+ throw new TableAlreadyExistsException(identifier)
+ }
+
+ sqlContext.sparkContext.parallelize(Seq.empty, 1)
+ }
+
+ override def output: Seq[Attribute] = Seq.empty
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
new file mode 100644
index 0000000000000..eed69cdc8cac6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability}
+
+object DataSourceV2Implicits {
+ implicit class TableHelper(table: Table) {
+ def asReadable: SupportsRead = {
+ table match {
+ case support: SupportsRead =>
+ support
+ case _ =>
+ throw new AnalysisException(s"Table does not support reads: ${table.name}")
+ }
+ }
+
+ def asWritable: SupportsWrite = {
+ table match {
+ case support: SupportsWrite =>
+ support
+ case _ =>
+ throw new AnalysisException(s"Table does not support writes: ${table.name}")
+ }
+ }
+
+ def supports(capability: TableCapability): Boolean = table.capabilities.contains(capability)
+
+ def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports)
+ }
+}
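A hedged call-site sketch against the helpers above (it only compiles on this in-flight branch, and the `table` value is assumed to come from an earlier TableCatalog lookup):

    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
    import org.apache.spark.sql.sources.v2.{Table, TableCapability}

    object TableHelperUsageSketch {
      def describeSupport(table: Table): String = {
        if (table.supportsAny(TableCapability.BATCH_READ, TableCapability.MICRO_BATCH_READ)) {
          // asReadable throws an AnalysisException with a readable message when the
          // table does not implement SupportsRead.
          s"${table.asReadable.name()} is readable"
        } else {
          s"${table.name()} is write-only"
        }
      }
    }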
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 47cf26dc9481e..fc919439d9224 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -17,20 +17,15 @@
package org.apache.spark.sql.execution.datasources.v2
-import java.util.UUID
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.{Statistics => V2Statistics, _}
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream}
import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* A logical plan representing a data source v2 table.
@@ -42,29 +37,21 @@ import org.apache.spark.sql.types.StructType
case class DataSourceV2Relation(
table: Table,
output: Seq[AttributeReference],
- options: Map[String, String])
+ options: CaseInsensitiveStringMap)
extends LeafNode with MultiInstanceRelation with NamedRelation {
+ import DataSourceV2Implicits._
+
override def name: String = table.name()
+ override def skipSchemaResolution: Boolean = table.supports(TableCapability.ACCEPT_ANY_SCHEMA)
+
override def simpleString(maxFields: Int): String = {
s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name"
}
- def newScanBuilder(): ScanBuilder = table match {
- case s: SupportsBatchRead =>
- val dsOptions = new DataSourceOptions(options.asJava)
- s.newScanBuilder(dsOptions)
- case _ => throw new AnalysisException(s"Table is not readable: ${table.name()}")
- }
-
- def newWriteBuilder(schema: StructType): WriteBuilder = table match {
- case s: SupportsBatchWrite =>
- val dsOptions = new DataSourceOptions(options.asJava)
- s.newWriteBuilder(dsOptions)
- .withQueryId(UUID.randomUUID().toString)
- .withInputDataSchema(schema)
- case _ => throw new AnalysisException(s"Table is not writable: ${table.name()}")
+ def newScanBuilder(): ScanBuilder = {
+ table.asReadable.newScanBuilder(options)
}
override def computeStats(): Statistics = {
@@ -72,7 +59,7 @@ case class DataSourceV2Relation(
scan match {
case r: SupportsReportStatistics =>
val statistics = r.estimateStatistics()
- Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
+ DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes)
case _ =>
Statistics(sizeInBytes = conf.defaultSizeInBytes)
}
@@ -105,15 +92,32 @@ case class StreamingDataSourceV2Relation(
override def computeStats(): Statistics = scan match {
case r: SupportsReportStatistics =>
val statistics = r.estimateStatistics()
- Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
+ DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes)
case _ =>
Statistics(sizeInBytes = conf.defaultSizeInBytes)
}
}
object DataSourceV2Relation {
- def create(table: Table, options: Map[String, String]): DataSourceV2Relation = {
+ def create(table: Table, options: CaseInsensitiveStringMap): DataSourceV2Relation = {
val output = table.schema().toAttributes
DataSourceV2Relation(table, output, options)
}
+
+ /**
+   * Converts data source v2 statistics into logical plan [[Statistics]].
+ */
+ def transformV2Stats(
+ v2Statistics: V2Statistics,
+ defaultRowCount: Option[BigInt],
+ defaultSizeInBytes: Long): Statistics = {
+ val numRows: Option[BigInt] = if (v2Statistics.numRows().isPresent) {
+ Some(v2Statistics.numRows().getAsLong)
+ } else {
+ defaultRowCount
+ }
+ Statistics(
+ sizeInBytes = v2Statistics.sizeInBytes().orElse(defaultSizeInBytes),
+ rowCount = numRows)
+ }
}
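transformV2Stats bridges the Java Optional-based v2 statistics into Catalyst's Option-based Statistics; the same conversion in a standalone sketch with plain java.util.OptionalLong values:

    import java.util.OptionalLong

    // Standalone sketch of the Optional-to-Option bridging done in transformV2Stats above.
    object V2StatsBridgeSketch {
      final case class SketchStats(sizeInBytes: BigInt, rowCount: Option[BigInt])

      def bridge(
          sizeInBytes: OptionalLong,
          numRows: OptionalLong,
          defaultRowCount: Option[BigInt],
          defaultSizeInBytes: Long): SketchStats = {
        val rows = if (numRows.isPresent) Some(BigInt(numRows.getAsLong)) else defaultRowCount
        SketchStats(BigInt(sizeInBytes.orElse(defaultSizeInBytes)), rows)
      }

      def main(args: Array[String]): Unit = {
        println(bridge(OptionalLong.of(1024L), OptionalLong.empty(), None, defaultSizeInBytes = 8L))
        // SketchStats(1024,None): size comes from the source, row count falls back to the default.
      }
    }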
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 40ac5cf402987..9889fd6731565 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -17,20 +17,22 @@
package org.apache.spark.sql.execution.datasources.v2
+import scala.collection.JavaConverters._
import scala.collection.mutable
-import org.apache.spark.sql.{sources, AnalysisException, SaveMode, Strategy}
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
+import org.apache.spark.sql.{AnalysisException, Strategy}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
+import org.apache.spark.sql.sources
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}
-import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
-object DataSourceV2Strategy extends Strategy {
+object DataSourceV2Strategy extends Strategy with PredicateHelper {
/**
* Pushes down filters to the data source reader
@@ -100,14 +102,22 @@ object DataSourceV2Strategy extends Strategy {
}
}
+ import DataSourceV2Implicits._
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
val scanBuilder = relation.newScanBuilder()
+
+ val (withSubquery, withoutSubquery) = filters.partition(SubqueryExpression.hasSubquery)
+ val normalizedFilters = DataSourceStrategy.normalizeFilters(
+ withoutSubquery, relation.output)
+
// `pushedFilters` will be pushed down and evaluated in the underlying data sources.
// `postScanFilters` need to be evaluated after the scan.
// `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
- val (pushedFilters, postScanFilters) = pushFilters(scanBuilder, filters)
+ val (pushedFilters, postScanFiltersWithoutSubquery) =
+ pushFilters(scanBuilder, normalizedFilters)
+ val postScanFilters = postScanFiltersWithoutSubquery ++ withSubquery
val (scan, output) = pruneColumns(scanBuilder, relation, project ++ postScanFilters)
logInfo(
s"""
@@ -142,15 +152,29 @@ object DataSourceV2Strategy extends Strategy {
case WriteToDataSourceV2(writer, query) =>
WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
+ case CreateV2Table(catalog, ident, schema, parts, props, ifNotExists) =>
+ CreateTableExec(catalog, ident, schema, parts, props, ifNotExists) :: Nil
+
+ case CreateTableAsSelect(catalog, ident, parts, query, props, options, ifNotExists) =>
+ val writeOptions = new CaseInsensitiveStringMap(options.asJava)
+ CreateTableAsSelectExec(
+ catalog, ident, parts, planLater(query), props, writeOptions, ifNotExists) :: Nil
+
case AppendData(r: DataSourceV2Relation, query, _) =>
- val writeBuilder = r.newWriteBuilder(query.schema)
- writeBuilder match {
- case s: SupportsSaveMode =>
- val write = s.mode(SaveMode.Append).buildForBatch()
- assert(write != null)
- WriteToDataSourceV2Exec(write, planLater(query)) :: Nil
- case _ => throw new AnalysisException(s"data source ${r.name} does not support SaveMode")
- }
+ AppendDataExec(r.table.asWritable, r.options, planLater(query)) :: Nil
+
+ case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) =>
+      // Fail if any filter cannot be converted; correctness depends on removing all matching data.
+      val filters = splitConjunctivePredicates(deleteExpr).map {
+        filter => DataSourceStrategy.translateFilter(filter).getOrElse(
+          throw new AnalysisException(s"Cannot translate expression to source filter: $filter"))
+ }.toArray
+
+ OverwriteByExpressionExec(
+ r.table.asWritable, filters, r.options, planLater(query)) :: Nil
+
+ case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) =>
+ OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil
case WriteToContinuousDataSource(writer, query) =>
WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil
@@ -167,6 +191,9 @@ object DataSourceV2Strategy extends Strategy {
Nil
}
+ case DropTable(catalog, ident, ifExists) =>
+ DropTableExec(catalog, ident, ifExists) :: Nil
+
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
deleted file mode 100644
index f11703c8a2773..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.v2
-
-import org.apache.commons.lang3.StringUtils
-
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.DataSourceV2
-import org.apache.spark.util.Utils
-
-/**
- * A trait that can be used by data source v2 related query plans(both logical and physical), to
- * provide a string format of the data source information for explain.
- */
-trait DataSourceV2StringFormat {
-
- /**
- * The instance of this data source implementation. Note that we only consider its class in
- * equals/hashCode, not the instance itself.
- */
- def source: DataSourceV2
-
- /**
- * The output of the data source reader, w.r.t. column pruning.
- */
- def output: Seq[Attribute]
-
- /**
- * The options for this data source reader.
- */
- def options: Map[String, String]
-
- /**
- * The filters which have been pushed to the data source.
- */
- def pushedFilters: Seq[Expression]
-
- private def sourceName: String = source match {
- case registered: DataSourceRegister => registered.shortName()
- // source.getClass.getSimpleName can cause Malformed class name error,
- // call safer `Utils.getSimpleName` instead
- case _ => Utils.getSimpleName(source.getClass)
- }
-
- def metadataString(maxFields: Int): String = {
- val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
-
- if (pushedFilters.nonEmpty) {
- entries += "Filters" -> pushedFilters.mkString("[", ", ", "]")
- }
-
- // TODO: we should only display some standard options like path, table, etc.
- if (options.nonEmpty) {
- entries += "Options" -> Utils.redact(options).map {
- case (k, v) => s"$k=$v"
- }.mkString("[", ",", "]")
- }
-
- val outputStr = truncatedString(output, "[", ", ", "]", maxFields)
-
- val entriesStr = if (entries.nonEmpty) {
- truncatedString(entries.map {
- case (key, value) => key + ": " + StringUtils.abbreviate(value, 100)
- }, " (", ", ", ")", maxFields)
- } else {
- ""
- }
-
- s"$sourceName$outputStr$entriesStr"
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
index e9cc3991155c4..30897d86f8179 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
@@ -21,8 +21,7 @@ import java.util.regex.Pattern
import org.apache.spark.internal.Logging
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport}
+import org.apache.spark.sql.sources.v2.{SessionConfigSupport, TableProvider}
private[sql] object DataSourceV2Utils extends Logging {
@@ -34,34 +33,28 @@ private[sql] object DataSourceV2Utils extends Logging {
* `spark.datasource.$keyPrefix`. A session config `spark.datasource.$keyPrefix.xxx -> yyy` will
* be transformed into `xxx -> yyy`.
*
- * @param ds a [[DataSourceV2]] object
+ * @param source a [[TableProvider]] object
* @param conf the session conf
* @return an immutable map that contains all the extracted and transformed k/v pairs.
*/
- def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match {
- case cs: SessionConfigSupport =>
- val keyPrefix = cs.keyPrefix()
- require(keyPrefix != null, "The data source config key prefix can't be null.")
-
- val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)")
-
- conf.getAllConfs.flatMap { case (key, value) =>
- val m = pattern.matcher(key)
- if (m.matches() && m.groupCount() > 0) {
- Seq((m.group(1), value))
- } else {
- Seq.empty
+ def extractSessionConfigs(source: TableProvider, conf: SQLConf): Map[String, String] = {
+ source match {
+ case cs: SessionConfigSupport =>
+ val keyPrefix = cs.keyPrefix()
+ require(keyPrefix != null, "The data source config key prefix can't be null.")
+
+ val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)")
+
+ conf.getAllConfs.flatMap { case (key, value) =>
+ val m = pattern.matcher(key)
+ if (m.matches() && m.groupCount() > 0) {
+ Seq((m.group(1), value))
+ } else {
+ Seq.empty
+ }
}
- }
-
- case _ => Map.empty
- }
- def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = {
- val name = ds match {
- case register: DataSourceRegister => register.shortName()
- case _ => ds.getClass.getName
+ case _ => Map.empty
}
- throw new UnsupportedOperationException(name + " source does not support user-specified schema")
}
}
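A standalone sketch of the key-rewriting rule in extractSessionConfigs; the keyPrefix "myds" and the config names are made up for illustration:

    import java.util.regex.Pattern

    object SessionConfigSketch {
      def extract(keyPrefix: String, allConfs: Map[String, String]): Map[String, String] = {
        val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)")
        allConfs.flatMap { case (key, value) =>
          val m = pattern.matcher(key)
          if (m.matches() && m.groupCount() > 0) Seq(m.group(1) -> value) else Seq.empty
        }
      }

      def main(args: Array[String]): Unit = {
        val confs = Map(
          "spark.datasource.myds.token" -> "abc",
          "spark.datasource.other.token" -> "xyz",
          "spark.sql.shuffle.partitions" -> "200")
        println(extract("myds", confs)) // Map(token -> abc)
      }
    }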
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala
new file mode 100644
index 0000000000000..d325e0205f9d8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.execution.LeafExecNode
+
+/**
+ * Physical plan node for dropping a table.
+ */
+case class DropTableExec(catalog: TableCatalog, ident: Identifier, ifExists: Boolean)
+ extends LeafExecNode {
+
+ override def doExecute(): RDD[InternalRow] = {
+ if (catalog.tableExists(ident)) {
+ catalog.dropTable(ident)
+ } else if (!ifExists) {
+ throw new NoSuchTableException(ident)
+ }
+
+ sqlContext.sparkContext.parallelize(Seq.empty, 1)
+ }
+
+ override def output: Seq[Attribute] = Seq.empty
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
index a0c932cbb0e09..e9c7a1bb749db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala
@@ -16,13 +16,13 @@
*/
package org.apache.spark.sql.execution.datasources.v2
-import scala.collection.JavaConverters._
+import com.fasterxml.jackson.databind.ObjectMapper
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.DataSourceRegister
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, SupportsBatchRead, TableProvider}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.sources.v2.TableProvider
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* A base interface for data source v2 implementations of the built-in file-based data sources.
@@ -39,16 +39,12 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister {
lazy val sparkSession = SparkSession.active
- def getFileIndex(
- options: DataSourceOptions,
- userSpecifiedSchema: Option[StructType]): PartitioningAwareFileIndex = {
- val filePaths = options.paths()
- val hadoopConf =
- sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap)
- val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf,
- checkEmptyGlobPath = true, checkFilesExist = options.checkFilesExist())
- val fileStatusCache = FileStatusCache.getOrCreate(sparkSession)
- new InMemoryFileIndex(sparkSession, rootPathsSpecified,
- options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache)
+ protected def getPaths(map: CaseInsensitiveStringMap): Seq[String] = {
+ val objectMapper = new ObjectMapper()
+ Option(map.get("paths")).map { pathStr =>
+ objectMapper.readValue(pathStr, classOf[Array[String]]).toSeq
+ }.getOrElse {
+ Option(map.get("path")).toSeq
+ }
}
}
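getPaths expects multiple paths to arrive as a JSON-encoded string array under the "paths" key, with a single "path" entry as the fallback; a standalone sketch of the same parsing, using a plain Map in place of CaseInsensitiveStringMap:

    import com.fasterxml.jackson.databind.ObjectMapper

    object GetPathsSketch {
      def getPaths(options: Map[String, String]): Seq[String] = {
        val objectMapper = new ObjectMapper()
        options.get("paths").map { pathStr =>
          // Multiple paths are assumed to be passed as a JSON array of strings under "paths".
          objectMapper.readValue(pathStr, classOf[Array[String]]).toSeq
        }.getOrElse {
          options.get("path").toSeq
        }
      }

      def main(args: Array[String]): Unit = {
        println(getPaths(Map("paths" -> """["/data/a","/data/b"]"""))) // List(/data/a, /data/b)
        println(getPaths(Map("path" -> "/data/a")))                    // List(/data/a)
      }
    }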
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 3615b15be6fd5..bdd6a48df20ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -18,15 +18,16 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.hadoop.fs.Path
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
abstract class FileScan(
sparkSession: SparkSession,
- fileIndex: PartitioningAwareFileIndex) extends Scan with Batch {
+ fileIndex: PartitioningAwareFileIndex,
+ readSchema: StructType) extends Scan with Batch {
/**
* Returns whether a file with `path` could be split or not.
*/
@@ -34,6 +35,22 @@ abstract class FileScan(
false
}
+ /**
+   * Returns whether this format supports the given [[DataType]] in the read path.
+ * By default all data types are supported.
+ */
+ def supportsDataType(dataType: DataType): Boolean = true
+
+ /**
+   * The string that represents the format this data source provider uses. Subclasses override
+   * this to provide a descriptive alias for the data source. For example:
+   *
+   * {{{
+   *   override def formatName: String = "ORC"
+ * }}}
+ */
+ def formatName: String
+
protected def partitions: Seq[FilePartition] = {
val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
@@ -57,5 +74,13 @@ abstract class FileScan(
partitions.toArray
}
- override def toBatch: Batch = this
+ override def toBatch: Batch = {
+ readSchema.foreach { field =>
+ if (!supportsDataType(field.dataType)) {
+ throw new AnalysisException(
+ s"$formatName data source does not support ${field.dataType.catalogString} data type.")
+ }
+ }
+ this
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
index 0dbef145f7326..9cf292782ffe0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
@@ -16,19 +16,41 @@
*/
package org.apache.spark.sql.execution.datasources.v2
+import java.util
+
+import scala.collection.JavaConverters._
+
import org.apache.hadoop.fs.FileStatus
import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table}
+import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
abstract class FileTable(
sparkSession: SparkSession,
- fileIndex: PartitioningAwareFileIndex,
+ options: CaseInsensitiveStringMap,
+ paths: Seq[String],
userSpecifiedSchema: Option[StructType])
- extends Table with SupportsBatchRead with SupportsBatchWrite {
- def getFileIndex: PartitioningAwareFileIndex = this.fileIndex
+ extends Table with SupportsRead with SupportsWrite {
+
+ import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+
+ lazy val fileIndex: PartitioningAwareFileIndex = {
+ val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
+ // Hadoop Configurations are case sensitive.
+ val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
+ // This is an internal config so must be present.
+ val checkFilesExist = options.get("check_files_exist").toBoolean
+ val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(paths, hadoopConf,
+ checkEmptyGlobPath = true, checkFilesExist = checkFilesExist)
+ val fileStatusCache = FileStatusCache.getOrCreate(sparkSession)
+ new InMemoryFileIndex(
+ sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache)
+ }
lazy val dataSchema: StructType = userSpecifiedSchema.orElse {
inferSchema(fileIndex.allFiles())
@@ -43,6 +65,12 @@ abstract class FileTable(
fileIndex.partitionSchema, caseSensitive)._1
}
+ override def partitioning: Array[Transform] = fileIndex.partitionSchema.asTransforms
+
+ override def properties: util.Map[String, String] = options.asCaseSensitiveMap
+
+ override def capabilities: java.util.Set[TableCapability] = FileTable.CAPABILITIES
+
/**
* When possible, this method should return the schema of the given `files`. When the format
* does not support inference, or no valid files are given should return None. In these cases
@@ -50,3 +78,7 @@ abstract class FileTable(
*/
def inferSchema(files: Seq[FileStatus]): Option[StructType]
}
+
+object FileTable {
+ private val CAPABILITIES = Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
index ce9b52f29d7bd..5375d965d1eff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.execution.datasources.v2
+import java.io.IOException
import java.util.UUID
import scala.collection.JavaConverters._
@@ -32,13 +33,16 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.sources.v2.writer.{BatchWrite, WriteBuilder}
+import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration
-abstract class FileWriteBuilder(options: DataSourceOptions)
- extends WriteBuilder with SupportsSaveMode {
+abstract class FileWriteBuilder(
+    options: CaseInsensitiveStringMap,
+    paths: Seq[String]) extends WriteBuilder {
private var schema: StructType = _
private var queryId: String = _
private var mode: SaveMode = _
@@ -53,25 +57,25 @@ abstract class FileWriteBuilder(options: DataSourceOptions)
this
}
- override def mode(mode: SaveMode): WriteBuilder = {
+ def mode(mode: SaveMode): WriteBuilder = {
this.mode = mode
this
}
override def buildForBatch(): BatchWrite = {
validateInputs()
- val pathName = options.paths().head
- val path = new Path(pathName)
+ val path = new Path(paths.head)
val sparkSession = SparkSession.active
- val optionsAsScala = options.asMap().asScala.toMap
- val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala)
+ val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
+ // Hadoop Configurations are case sensitive.
+ val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
val job = getJobInstance(hadoopConf, path)
val committer = FileCommitProtocol.instantiate(
sparkSession.sessionState.conf.fileCommitProtocolClass,
jobId = java.util.UUID.randomUUID().toString,
- outputPath = pathName)
+ outputPath = paths.head)
lazy val description =
- createWriteJobDescription(sparkSession, hadoopConf, job, pathName, optionsAsScala)
+ createWriteJobDescription(sparkSession, hadoopConf, job, paths.head, options.asScala.toMap)
val fs = path.getFileSystem(hadoopConf)
mode match {
@@ -83,7 +87,9 @@ abstract class FileWriteBuilder(options: DataSourceOptions)
null
case SaveMode.Overwrite =>
- committer.deleteWithJob(fs, path, true)
+ if (fs.exists(path) && !committer.deleteWithJob(fs, path, true)) {
+ throw new IOException(s"Unable to clear directory $path prior to writing to it")
+ }
committer.setupJob(job)
new FileBatchWrite(job, description, committer)
@@ -104,12 +110,35 @@ abstract class FileWriteBuilder(options: DataSourceOptions)
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory
+ /**
+   * Returns whether this format supports the given [[DataType]] in the write path.
+ * By default all data types are supported.
+ */
+ def supportsDataType(dataType: DataType): Boolean = true
+
+ /**
+   * The string that represents the format this data source provider uses. Subclasses override
+   * this to provide a descriptive alias for the data source. For example:
+   *
+   * {{{
+   *   override def formatName: String = "ORC"
+ * }}}
+ */
+ def formatName: String
+
private def validateInputs(): Unit = {
assert(schema != null, "Missing input data schema")
assert(queryId != null, "Missing query ID")
assert(mode != null, "Missing save mode")
- assert(options.paths().length == 1)
+ assert(paths.length == 1)
DataSource.validateSchema(schema)
+ schema.foreach { field =>
+      if (!supportsDataType(field.dataType)) {
+ throw new AnalysisException(
+ s"$formatName data source does not support ${field.dataType.catalogString}" +
+ s" data type.")
+ }
+ }
}
private def getJobInstance(hadoopConf: Configuration, path: Path): Job = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala
new file mode 100644
index 0000000000000..c029acc0bb2df
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
+import org.apache.spark.sql.sources.v2.TableCapability.{CONTINUOUS_READ, MICRO_BATCH_READ}
+
+/**
+ * This rule adds basic table capability checks for streaming scans, without knowing the actual
+ * streaming execution mode.
+ */
+object V2StreamingScanSupportCheck extends (LogicalPlan => Unit) {
+ import DataSourceV2Implicits._
+
+ override def apply(plan: LogicalPlan): Unit = {
+ plan.foreach {
+ case r: StreamingRelationV2 if !r.table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) =>
+ throw new AnalysisException(
+ s"Table ${r.table.name()} does not support either micro-batch or continuous scan.")
+ case _ =>
+ }
+
+ val streamingSources = plan.collect {
+ case r: StreamingRelationV2 => r.table
+ }
+ val v1StreamingRelations = plan.collect {
+ case r: StreamingRelation => r
+ }
+
+ if (streamingSources.length + v1StreamingRelations.length > 1) {
+ val allSupportsMicroBatch = streamingSources.forall(_.supports(MICRO_BATCH_READ))
+ // v1 streaming data source only supports micro-batch.
+ val allSupportsContinuous = streamingSources.forall(_.supports(CONTINUOUS_READ)) &&
+ v1StreamingRelations.isEmpty
+ if (!allSupportsMicroBatch && !allSupportsContinuous) {
+ val microBatchSources =
+ streamingSources.filter(_.supports(MICRO_BATCH_READ)).map(_.name()) ++
+ v1StreamingRelations.map(_.sourceName)
+ val continuousSources = streamingSources.filter(_.supports(CONTINUOUS_READ)).map(_.name())
+ throw new AnalysisException(
+ "The streaming sources in a query do not have a common supported execution mode.\n" +
+ "Sources support micro-batch: " + microBatchSources.mkString(", ") + "\n" +
+ "Sources support continuous: " + continuousSources.mkString(", "))
+ }
+ }
+ }
+}
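To make the mode-compatibility decision above concrete, here is a rough standalone sketch of the same logic reduced to plain booleans (`SourceCaps` and `commonMode` are made-up names, not part of the patch): each v2 source contributes its MICRO_BATCH_READ / CONTINUOUS_READ support, and any v1 source rules out continuous execution.

  object ScanModeSketch {
    final case class SourceCaps(name: String, microBatch: Boolean, continuous: Boolean)

    // Returns an execution mode every source supports, if one exists; the real rule
    // only checks compatibility and throws an AnalysisException when this is None.
    def commonMode(v2Sources: Seq[SourceCaps], hasV1Source: Boolean): Option[String] = {
      val allMicroBatch = v2Sources.forall(_.microBatch)
      val allContinuous = v2Sources.forall(_.continuous) && !hasV1Source
      if (allMicroBatch) Some("micro-batch")
      else if (allContinuous) Some("continuous")
      else None
    }
  }

  import ScanModeSketch._
  commonMode(Seq(SourceCaps("a", microBatch = true, continuous = true)), hasV1Source = true)
  // -> Some("micro-batch"): the v1 source forces micro-batch, which "a" also supports
  commonMode(Seq(SourceCaps("b", microBatch = false, continuous = true)), hasV1Source = true)
  // -> None: "b" is continuous-only, but a v1 source cannot run continuously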
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala
new file mode 100644
index 0000000000000..cf77998c122f8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
+import org.apache.spark.sql.sources.v2.TableCapability._
+import org.apache.spark.sql.types.BooleanType
+
+object V2WriteSupportCheck extends (LogicalPlan => Unit) {
+ import DataSourceV2Implicits._
+
+ def failAnalysis(msg: String): Unit = throw new AnalysisException(msg)
+
+ override def apply(plan: LogicalPlan): Unit = plan foreach {
+ case AppendData(rel: DataSourceV2Relation, _, _) if !rel.table.supports(BATCH_WRITE) =>
+ failAnalysis(s"Table does not support append in batch mode: ${rel.table}")
+
+ case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _)
+ if !rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_DYNAMIC) =>
+ failAnalysis(s"Table does not support dynamic overwrite in batch mode: ${rel.table}")
+
+ case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _) =>
+ expr match {
+ case Literal(true, BooleanType) =>
+ if (!rel.table.supports(BATCH_WRITE) ||
+ !rel.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) {
+ failAnalysis(
+ s"Table does not support truncate in batch mode: ${rel.table}")
+ }
+ case _ =>
+ if (!rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_BY_FILTER)) {
+ failAnalysis(s"Table does not support overwrite expression ${expr.sql} " +
+ s"in batch mode: ${rel.table}")
+ }
+ }
+
+ case _ => // OK
+ }
+}
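The truncate branch above boils down to a capability check: SaveMode.Overwrite arrives as an OverwriteByExpression with a literal true condition, and a table may satisfy it with either TRUNCATE or OVERWRITE_BY_FILTER. A minimal sketch of that predicate over a plain capability set (`canTruncate` is an illustrative name, not part of the patch):

  import org.apache.spark.sql.sources.v2.TableCapability
  import org.apache.spark.sql.sources.v2.TableCapability._

  // True if a batch truncate-and-append (SaveMode.Overwrite) would pass the check.
  def canTruncate(caps: Set[TableCapability]): Boolean =
    caps.contains(BATCH_WRITE) && (caps.contains(TRUNCATE) || caps.contains(OVERWRITE_BY_FILTER))

  canTruncate(Set(BATCH_WRITE, TRUNCATE))            // true
  canTruncate(Set(BATCH_WRITE, OVERWRITE_BY_FILTER)) // true
  canTruncate(Set(BATCH_WRITE))                      // false -> "Table does not support truncate"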
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 50c5e4f2ad7df..6c771ea988324 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -17,17 +17,26 @@
package org.apache.spark.sql.execution.datasources.v2
+import java.util.UUID
+
+import scala.collection.JavaConverters._
import scala.util.control.NonFatal
import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog}
+import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
+import org.apache.spark.sql.sources.v2.SupportsWrite
+import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.{LongAccumulator, Utils}
/**
@@ -42,17 +51,170 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan)
}
/**
- * The physical plan for writing data into data source v2.
+ * Physical plan node for v2 create table as select.
+ *
+ * A new table will be created using the schema of the query, and rows from the query are appended.
+ * If either table creation or the append fails, the table will be deleted. This implementation does
+ * not provide an atomic CTAS.
+ */
+case class CreateTableAsSelectExec(
+ catalog: TableCatalog,
+ ident: Identifier,
+ partitioning: Seq[Transform],
+ query: SparkPlan,
+ properties: Map[String, String],
+ writeOptions: CaseInsensitiveStringMap,
+ ifNotExists: Boolean) extends V2TableWriteExec {
+
+ import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ if (catalog.tableExists(ident)) {
+ if (ifNotExists) {
+ return sparkContext.parallelize(Seq.empty, 1)
+ }
+
+ throw new TableAlreadyExistsException(ident)
+ }
+
+ Utils.tryWithSafeFinallyAndFailureCallbacks({
+ catalog.createTable(ident, query.schema, partitioning.toArray, properties.asJava) match {
+ case table: SupportsWrite =>
+ val batchWrite = table.newWriteBuilder(writeOptions)
+ .withInputDataSchema(query.schema)
+ .withQueryId(UUID.randomUUID().toString)
+ .buildForBatch()
+
+ doWrite(batchWrite)
+
+ case _ =>
+ // table does not support writes
+ throw new SparkException(s"Table implementation does not support writes: ${ident.quoted}")
+ }
+
+ })(catchBlock = {
+ catalog.dropTable(ident)
+ })
+ }
+}
+
+/**
+ * Physical plan node for appending data to a v2 table.
+ *
+ * Rows in the output data set are appended to the table.
*/
-case class WriteToDataSourceV2Exec(batchWrite: BatchWrite, query: SparkPlan)
- extends UnaryExecNode {
+case class AppendDataExec(
+ table: SupportsWrite,
+ writeOptions: CaseInsensitiveStringMap,
+ query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val batchWrite = newWriteBuilder().buildForBatch()
+ doWrite(batchWrite)
+ }
+}
+
+/**
+ * Physical plan node for overwriting data in a v2 table.
+ *
+ * Overwrites rows in a table that match a set of filters. Rows matching all of the filters are
+ * deleted, and rows in the output data set are appended.
+ *
+ * This plan is used to implement SaveMode.Overwrite. The behavior of SaveMode.Overwrite is to
+ * truncate the table -- delete all rows -- and append the output data set. This uses the filter
+ * AlwaysTrue to delete all rows.
+ */
+case class OverwriteByExpressionExec(
+ table: SupportsWrite,
+ deleteWhere: Array[Filter],
+ writeOptions: CaseInsensitiveStringMap,
+ query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
+
+ private def isTruncate(filters: Array[Filter]): Boolean = {
+ filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val batchWrite = newWriteBuilder() match {
+ case builder: SupportsTruncate if isTruncate(deleteWhere) =>
+ builder.truncate().buildForBatch()
+
+ case builder: SupportsOverwrite =>
+ builder.overwrite(deleteWhere).buildForBatch()
+
+ case _ =>
+ throw new SparkException(s"Table does not support overwrite by expression: $table")
+ }
+
+ doWrite(batchWrite)
+ }
+}
+
+/**
+ * Physical plan node for dynamic partition overwrite into a v2 table.
+ *
+ * Dynamic partition overwrite is the behavior of Hive INSERT OVERWRITE ... PARTITION queries, and
+ * Spark INSERT OVERWRITE queries when spark.sql.sources.partitionOverwriteMode=dynamic. Each
+ * partition in the output data set replaces the corresponding existing partition in the table or
+ * creates a new partition. Existing partitions for which there is no data in the output data set
+ * are not modified.
+ */
+case class OverwritePartitionsDynamicExec(
+ table: SupportsWrite,
+ writeOptions: CaseInsensitiveStringMap,
+ query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper {
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ val batchWrite = newWriteBuilder() match {
+ case builder: SupportsDynamicOverwrite =>
+ builder.overwriteDynamicPartitions().buildForBatch()
+
+ case _ =>
+ throw new SparkException(s"Table does not support dynamic partition overwrite: $table")
+ }
+
+ doWrite(batchWrite)
+ }
+}
+
+case class WriteToDataSourceV2Exec(
+ batchWrite: BatchWrite,
+ query: SparkPlan) extends V2TableWriteExec {
+
+ def writeOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty()
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ doWrite(batchWrite)
+ }
+}
+
+/**
+ * Helper for physical plans that build batch writes.
+ */
+trait BatchWriteHelper {
+ def table: SupportsWrite
+ def query: SparkPlan
+ def writeOptions: CaseInsensitiveStringMap
+
+ def newWriteBuilder(): WriteBuilder = {
+ table.newWriteBuilder(writeOptions)
+ .withInputDataSchema(query.schema)
+ .withQueryId(UUID.randomUUID().toString)
+ }
+}
+
+/**
+ * The base physical plan for writing data into data source v2.
+ */
+trait V2TableWriteExec extends UnaryExecNode {
+ def query: SparkPlan
var commitProgress: Option[StreamWriterCommitProgress] = None
override def child: SparkPlan = query
override def output: Seq[Attribute] = Nil
- override protected def doExecute(): RDD[InternalRow] = {
+ protected def doWrite(batchWrite: BatchWrite): RDD[InternalRow] = {
val writerFactory = batchWrite.createBatchWriterFactory()
val useCommitCoordinator = batchWrite.useCommitCoordinator
val rdd = query.execute()
@@ -169,8 +331,8 @@ object DataWritingSparkTask extends Logging {
}
private[v2] case class DataWritingSparkTaskResult(
- numRows: Long,
- writerCommitMessage: WriterCommitMessage)
+ numRows: Long,
+ writerCommitMessage: WriterCommitMessage)
/**
* Sink progress information collected after commit.
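A small usage sketch of the isTruncate check in OverwriteByExpressionExec (the example filters below are hypothetical): SaveMode.Overwrite is represented by the single AlwaysTrue filter, while any other filter array takes the overwrite-by-filter path.

  import org.apache.spark.sql.sources.{AlwaysTrue, EqualTo, Filter}

  def isTruncate(filters: Array[Filter]): Boolean =
    filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]

  isTruncate(Array[Filter](AlwaysTrue()))                 // true  -> builder.truncate()
  isTruncate(Array[Filter](EqualTo("day", "2019-04-01"))) // false -> builder.overwrite(filters)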
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
index db1f2f7934221..900c94e937ffc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2.orc
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.v2._
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.sources.v2.Table
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
class OrcDataSourceV2 extends FileDataSourceV2 {
@@ -28,19 +29,36 @@ class OrcDataSourceV2 extends FileDataSourceV2 {
override def shortName(): String = "orc"
- private def getTableName(options: DataSourceOptions): String = {
- shortName() + ":" + options.paths().mkString(";")
+ private def getTableName(paths: Seq[String]): String = {
+ shortName() + ":" + paths.mkString(";")
}
- override def getTable(options: DataSourceOptions): Table = {
- val tableName = getTableName(options)
- val fileIndex = getFileIndex(options, None)
- OrcTable(tableName, sparkSession, fileIndex, None)
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ val paths = getPaths(options)
+ val tableName = getTableName(paths)
+ OrcTable(tableName, sparkSession, options, paths, None)
}
- override def getTable(options: DataSourceOptions, schema: StructType): Table = {
- val tableName = getTableName(options)
- val fileIndex = getFileIndex(options, Some(schema))
- OrcTable(tableName, sparkSession, fileIndex, Some(schema))
+ override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
+ val paths = getPaths(options)
+ val tableName = getTableName(paths)
+ OrcTable(tableName, sparkSession, options, paths, Some(schema))
+ }
+}
+
+object OrcDataSourceV2 {
+ def supportsDataType(dataType: DataType): Boolean = dataType match {
+ case _: AtomicType => true
+
+ case st: StructType => st.forall { f => supportsDataType(f.dataType) }
+
+ case ArrayType(elementType, _) => supportsDataType(elementType)
+
+ case MapType(keyType, valueType, _) =>
+ supportsDataType(keyType) && supportsDataType(valueType)
+
+ case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+
+ case _ => false
}
}
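As a quick illustration of the recursive check in OrcDataSourceV2.supportsDataType, the expected results below follow directly from the cases above rather than from running the ORC writer:

  import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
  import org.apache.spark.sql.types._

  OrcDataSourceV2.supportsDataType(IntegerType)                 // true: atomic type
  OrcDataSourceV2.supportsDataType(ArrayType(StringType))       // true: supported element type
  OrcDataSourceV2.supportsDataType(
    MapType(StringType, new StructType().add("a", LongType)))   // true: key and value both supported
  OrcDataSourceV2.supportsDataType(CalendarIntervalType)        // false: no matching case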
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index a792ad318b398..3c5dc1f50d7e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.SerializableConfiguration
case class OrcScan(
@@ -31,7 +31,7 @@ case class OrcScan(
hadoopConf: Configuration,
fileIndex: PartitioningAwareFileIndex,
dataSchema: StructType,
- readSchema: StructType) extends FileScan(sparkSession, fileIndex) {
+ readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) {
override def isSplitable(path: Path): Boolean = true
override def createReaderFactory(): PartitionReaderFactory = {
@@ -40,4 +40,10 @@ case class OrcScan(
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, fileIndex.partitionSchema, readSchema)
}
+
+ override def supportsDataType(dataType: DataType): Boolean = {
+ OrcDataSourceV2.supportsDataType(dataType)
+ }
+
+ override def formatName: String = "ORC"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
index eb27bbd3abeaa..a2c55e8c43021 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
@@ -26,18 +26,21 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.orc.OrcFilters
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader.Scan
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class OrcScanBuilder(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
schema: StructType,
dataSchema: StructType,
- options: DataSourceOptions) extends FileScanBuilder(schema) {
- lazy val hadoopConf =
- sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap)
+ options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) {
+ lazy val hadoopConf = {
+ val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
+ // Hadoop Configurations are case sensitive.
+ sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
+ }
override def build(): Scan = {
OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readSchema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
index b467e505f1bac..aac38fb3fa1ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
@@ -19,25 +19,26 @@ package org.apache.spark.sql.execution.datasources.v2.orc
import org.apache.hadoop.fs.FileStatus
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.orc.OrcUtils
import org.apache.spark.sql.execution.datasources.v2.FileTable
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.WriteBuilder
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class OrcTable(
name: String,
sparkSession: SparkSession,
- fileIndex: PartitioningAwareFileIndex,
+ options: CaseInsensitiveStringMap,
+ paths: Seq[String],
userSpecifiedSchema: Option[StructType])
- extends FileTable(sparkSession, fileIndex, userSpecifiedSchema) {
- override def newScanBuilder(options: DataSourceOptions): OrcScanBuilder =
+ extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): OrcScanBuilder =
new OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options)
override def inferSchema(files: Seq[FileStatus]): Option[StructType] =
OrcUtils.readSchema(sparkSession, files)
- override def newWriteBuilder(options: DataSourceOptions): WriteBuilder =
- new OrcWriteBuilder(options)
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
+ new OrcWriteBuilder(options, paths)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
index 80429d91d5e4d..b1f8b8916a390 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
@@ -25,10 +25,16 @@ import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFac
import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcOutputWriter, OrcUtils}
import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
+ extends FileWriteBuilder(
+ options,
+ paths,
+ "orc",
+ supportsDataType = OrcDataSourceV2.supportsDataType) {
-class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(options) {
override def prepareWrite(
sqlConf: SQLConf,
job: Job,
@@ -63,4 +69,10 @@ class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(optio
}
}
}
+
+ override def supportsDataType(dataType: DataType): Boolean = {
+ OrcDataSourceV2.supportsDataType(dataType)
+ }
+
+ override def formatName: String = "ORC"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 2ab7240556aaa..0c78cca086ed3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index a5203daea9cd0..d1105f0382f6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
/**
* Grouped a iterator into batches.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 04623b1ab3c2f..3710218b2af5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -29,8 +29,9 @@ import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter}
import org.apache.spark._
import org.apache.spark.api.python._
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.util.Utils
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index e9cff1a5a2007..c598b7c671a42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
/**
* Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index 1ce1215bfdd62..01ce07b133ffd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -29,9 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan}
-import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.execution.window._
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index b3d12f67b5d63..b679f163fc561 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.streaming
import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileCommitProtocol
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.SerializableConfiguration
object FileStreamSink extends Logging {
@@ -37,23 +39,54 @@ object FileStreamSink extends Logging {
* Returns true if there is a single path that has a metadata log indicating which files should
* be read.
*/
- def hasMetadata(path: Seq[String], hadoopConf: Configuration): Boolean = {
+ def hasMetadata(path: Seq[String], hadoopConf: Configuration, sqlConf: SQLConf): Boolean = {
path match {
case Seq(singlePath) =>
+ val hdfsPath = new Path(singlePath)
+ val fs = hdfsPath.getFileSystem(hadoopConf)
+ if (fs.isDirectory(hdfsPath)) {
+ val metadataPath = new Path(hdfsPath, metadataDir)
+ checkEscapedMetadataPath(fs, metadataPath, sqlConf)
+ fs.exists(metadataPath)
+ } else {
+ false
+ }
+ case _ => false
+ }
+ }
+
+ def checkEscapedMetadataPath(fs: FileSystem, metadataPath: Path, sqlConf: SQLConf): Unit = {
+ if (sqlConf.getConf(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED)
+ && StreamExecution.containsSpecialCharsInPath(metadataPath)) {
+ val legacyMetadataPath = new Path(metadataPath.toUri.toString)
+ val legacyMetadataPathExists =
try {
- val hdfsPath = new Path(singlePath)
- val fs = hdfsPath.getFileSystem(hadoopConf)
- if (fs.isDirectory(hdfsPath)) {
- fs.exists(new Path(hdfsPath, metadataDir))
- } else {
- false
- }
+ fs.exists(legacyMetadataPath)
} catch {
case NonFatal(e) =>
- logWarning(s"Error while looking for metadata directory.")
+ // We may not have access to this directory. Don't fail the query if that happens.
+ logWarning(e.getMessage, e)
false
}
- case _ => false
+ if (legacyMetadataPathExists) {
+ throw new SparkException(
+ s"""Error: we detected a possible problem with the location of your "_spark_metadata"
+ |directory and you likely need to move it before restarting this query.
+ |
+ |Earlier versions of Spark incorrectly escaped paths when writing out the
+ |"_spark_metadata" directory for structured streaming. While this was corrected in
+ |Spark 3.0, it appears that your query was started using an earlier version that
+ |incorrectly handled the "_spark_metadata" path.
+ |
+ |Correct "_spark_metadata" Directory: $metadataPath
+ |Incorrect "_spark_metadata" Directory: $legacyMetadataPath
+ |
+ |Please move the data from the incorrect directory to the correct one, delete the
+ |incorrect directory, and then restart this query. If you believe you are receiving
+ |this message in error, you can disable it with the SQL conf
+ |${SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key}."""
+ .stripMargin)
+ }
}
}
@@ -92,11 +125,16 @@ class FileStreamSink(
partitionColumnNames: Seq[String],
options: Map[String, String]) extends Sink with Logging {
+ private val hadoopConf = sparkSession.sessionState.newHadoopConf()
private val basePath = new Path(path)
- private val logPath = new Path(basePath, FileStreamSink.metadataDir)
+ private val logPath = {
+ val metadataDir = new Path(basePath, FileStreamSink.metadataDir)
+ val fs = metadataDir.getFileSystem(hadoopConf)
+ FileStreamSink.checkEscapedMetadataPath(fs, metadataDir, sparkSession.sessionState.conf)
+ metadataDir
+ }
private val fileLog =
- new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString)
- private val hadoopConf = sparkSession.sessionState.newHadoopConf()
+ new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toString)
private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = {
val serializableHadoopConf = new SerializableConfiguration(hadoopConf)
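The escaped-path check above exists because wrapping a Path's URI string in a new Path percent-encodes special characters a second time, so queries started on older Spark versions may have written their metadata under a literally escaped directory. A minimal sketch of the effect (the paths are made up; the behavior is the Hadoop Path/URI round trip described in the error message):

  import org.apache.hadoop.fs.Path

  val metadataPath = new Path("/data/out dir/_spark_metadata")
  // Re-parsing the URI string escapes the space, yielding the "legacy" location.
  val legacyMetadataPath = new Path(metadataPath.toUri.toString)

  println(metadataPath)        // /data/out dir/_spark_metadata
  println(legacyMetadataPath)  // /data/out%20dir/_spark_metadata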
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 103fa7ce9066d..43b70ae0a51b1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -208,7 +208,7 @@ class FileStreamSource(
var allFiles: Seq[FileStatus] = null
sourceHasMetadata match {
case None =>
- if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) {
+ if (FileStreamSink.hasMetadata(Seq(path), hadoopConf, sparkSession.sessionState.conf)) {
sourceHasMetadata = Some(true)
allFiles = allFilesUsingMetadataLogFileIndex()
} else {
@@ -220,7 +220,7 @@ class FileStreamSource(
// double check whether source has metadata, preventing the extreme corner case that
// metadata log and data files are only generated after the previous
// `FileStreamSink.hasMetadata` check
- if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) {
+ if (FileStreamSink.hasMetadata(Seq(path), hadoopConf, sparkSession.sessionState.conf)) {
sourceHasMetadata = Some(true)
allFiles = allFilesUsingMetadataLogFileIndex()
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
index 3ff5b86ac45d6..a27898cb0c9fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
@@ -17,12 +17,10 @@
package org.apache.spark.sql.execution.streaming
-import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
-
/**
* A simple offset for sources that produce a single linear stream of data.
*/
-case class LongOffset(offset: Long) extends OffsetV2 {
+case class LongOffset(offset: Long) extends Offset {
override val json = offset.toString
@@ -37,14 +35,4 @@ object LongOffset {
* @return new LongOffset
*/
def apply(offset: SerializedOffset) : LongOffset = new LongOffset(offset.json.toLong)
-
- /**
- * Convert generic Offset to LongOffset if possible.
- * @return converted LongOffset
- */
- def convert(offset: Offset): Option[LongOffset] = offset match {
- case lo: LongOffset => Some(lo)
- case so: SerializedOffset => Some(LongOffset(so))
- case _ => None
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
index 5cacdd070b735..80eed7b277216 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
@@ -39,10 +39,16 @@ class MetadataLogFileIndex(
userSpecifiedSchema: Option[StructType])
extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) {
- private val metadataDirectory = new Path(path, FileStreamSink.metadataDir)
+ private val metadataDirectory = {
+ val metadataDir = new Path(path, FileStreamSink.metadataDir)
+ val fs = metadataDir.getFileSystem(sparkSession.sessionState.newHadoopConf())
+ FileStreamSink.checkEscapedMetadataPath(fs, metadataDir, sparkSession.sessionState.conf)
+ metadataDir
+ }
+
logInfo(s"Reading streaming file log from $metadataDirectory")
private val metadataLog =
- new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toUri.toString)
+ new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toString)
private val allFilesFromLog = metadataLog.allFiles().map(_.toFileStatus).filterNot(_.isDirectory)
private var cachedPartitionSpec: PartitionSpec = _
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 2c339759f95ba..7a3cdbc926446 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.streaming
-import scala.collection.JavaConverters._
import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.sql.{Dataset, SparkSession}
@@ -26,11 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatch
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2, WriteToDataSourceV2Exec}
-import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, RateControlMicroBatchStream}
+import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
+import org.apache.spark.sql.execution.streaming.sources.{RateControlMicroBatchStream, WriteToMicroBatchDataSource}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2._
-import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2, SparkDataStream}
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.Clock
@@ -39,7 +38,7 @@ class MicroBatchExecution(
name: String,
checkpointRoot: String,
analyzedPlan: LogicalPlan,
- sink: BaseStreamingSink,
+ sink: Table,
trigger: Trigger,
triggerClock: Clock,
outputMode: OutputMode,
@@ -49,7 +48,7 @@ class MicroBatchExecution(
sparkSession, name, checkpointRoot, analyzedPlan, sink,
trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
- @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty
+ @volatile protected var sources: Seq[SparkDataStream] = Seq.empty
private val triggerExecutor = trigger match {
case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock)
@@ -78,6 +77,7 @@ class MicroBatchExecution(
val disabledSources =
sparkSession.sqlContext.conf.disabledV2StreamingMicroBatchReaders.split(",")
+ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
val _logicalPlan = analyzedPlan.transform {
case streamingRelation@StreamingRelation(dataSourceV1, sourceName, output) =>
toExecutionRelationMap.getOrElseUpdate(streamingRelation, {
@@ -88,32 +88,33 @@ class MicroBatchExecution(
logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]")
StreamingExecutionRelation(source, output)(sparkSession)
})
- case s @ StreamingRelationV2(ds, dsName, table: SupportsMicroBatchRead, options, output, _)
- if !disabledSources.contains(ds.getClass.getCanonicalName) =>
- v2ToRelationMap.getOrElseUpdate(s, {
- // Materialize source to avoid creating it in every batch
- val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
- nextSourceId += 1
- logInfo(s"Reading table [$table] from DataSourceV2 named '$dsName' [$ds]")
- val dsOptions = new DataSourceOptions(options.asJava)
- // TODO: operator pushdown.
- val scan = table.newScanBuilder(dsOptions).build()
- val stream = scan.toMicroBatchStream(metadataPath)
- StreamingDataSourceV2Relation(output, scan, stream)
- })
- case s @ StreamingRelationV2(ds, dsName, _, _, output, v1Relation) =>
- v2ToExecutionRelationMap.getOrElseUpdate(s, {
- // Materialize source to avoid creating it in every batch
- val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
- if (v1Relation.isEmpty) {
- throw new UnsupportedOperationException(
- s"Data source $dsName does not support microbatch processing.")
- }
- val source = v1Relation.get.dataSource.createSource(metadataPath)
- nextSourceId += 1
- logInfo(s"Using Source [$source] from DataSourceV2 named '$dsName' [$ds]")
- StreamingExecutionRelation(source, output)(sparkSession)
- })
+
+ case s @ StreamingRelationV2(src, srcName, table: SupportsRead, options, output, v1) =>
+ val v2Disabled = disabledSources.contains(src.getClass.getCanonicalName)
+ if (!v2Disabled && table.supports(TableCapability.MICRO_BATCH_READ)) {
+ v2ToRelationMap.getOrElseUpdate(s, {
+ // Materialize source to avoid creating it in every batch
+ val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
+ nextSourceId += 1
+ logInfo(s"Reading table [$table] from DataSourceV2 named '$srcName' [$src]")
+ // TODO: operator pushdown.
+ val scan = table.newScanBuilder(options).build()
+ val stream = scan.toMicroBatchStream(metadataPath)
+ StreamingDataSourceV2Relation(output, scan, stream)
+ })
+ } else if (v1.isEmpty) {
+ throw new UnsupportedOperationException(
+ s"Data source $srcName does not support microbatch processing.")
+ } else {
+ v2ToExecutionRelationMap.getOrElseUpdate(s, {
+ // Materialize source to avoid creating it in every batch
+ val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
+ val source = v1.get.dataSource.createSource(metadataPath)
+ nextSourceId += 1
+ logInfo(s"Using Source [$source] from DataSourceV2 named '$srcName' [$src]")
+ StreamingExecutionRelation(source, output)(sparkSession)
+ })
+ }
}
sources = _logicalPlan.collect {
// v1 source
@@ -122,7 +123,15 @@ class MicroBatchExecution(
case r: StreamingDataSourceV2Relation => r.stream
}
uniqueSources = sources.distinct
- _logicalPlan
+
+ // TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
+ sink match {
+ case s: SupportsWrite =>
+ val streamingWrite = createStreamingWrite(s, extraOptions, _logicalPlan)
+ WriteToMicroBatchDataSource(streamingWrite, _logicalPlan)
+
+ case _ => _logicalPlan
+ }
}
/**
@@ -287,7 +296,7 @@ class MicroBatchExecution(
* batch will be executed before getOffset is called again. */
availableOffsets.foreach {
case (source: Source, end: Offset) =>
- val start = committedOffsets.get(source)
+ val start = committedOffsets.get(source).map(_.asInstanceOf[Offset])
source.getBatch(start, end)
case nonV1Tuple =>
// The V2 API does not have the same edge case requiring getBatch to be called
@@ -345,7 +354,7 @@ class MicroBatchExecution(
if (isCurrentBatchConstructed) return true
// Generate a map from each unique source to the next available offset.
- val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map {
+ val latestOffsets: Map[SparkDataStream, Option[OffsetV2]] = uniqueSources.map {
case s: Source =>
updateStatusMessage(s"Getting offsets from $s")
reportTimeTaken("getOffset") {
@@ -402,7 +411,7 @@ class MicroBatchExecution(
val prevBatchOff = offsetLog.get(currentBatchId - 1)
if (prevBatchOff.isDefined) {
prevBatchOff.get.toStreamProgress(sources).foreach {
- case (src: Source, off) => src.commit(off)
+ case (src: Source, off: Offset) => src.commit(off)
case (stream: MicroBatchStream, off) =>
stream.commit(stream.deserializeOffset(off.json))
case (src, _) =>
@@ -439,9 +448,9 @@ class MicroBatchExecution(
// Request unprocessed data from all sources.
newData = reportTimeTaken("getBatch") {
availableOffsets.flatMap {
- case (source: Source, available)
+ case (source: Source, available: Offset)
if committedOffsets.get(source).map(_ != available).getOrElse(true) =>
- val current = committedOffsets.get(source)
+ val current = committedOffsets.get(source).map(_.asInstanceOf[Offset])
val batch = source.getBatch(current, available)
assert(batch.isStreaming,
s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" +
@@ -513,13 +522,8 @@ class MicroBatchExecution(
val triggerLogicalPlan = sink match {
case _: Sink => newAttributePlan
- case s: StreamingWriteSupportProvider =>
- val writer = s.createStreamingWriteSupport(
- s"$runId",
- newAttributePlan.schema,
- outputMode,
- new DataSourceOptions(extraOptions.asJava))
- WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, writer), newAttributePlan)
+ case _: SupportsWrite =>
+ newAttributePlan.asInstanceOf[WriteToMicroBatchDataSource].createPlan(currentBatchId)
case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
}
@@ -549,7 +553,7 @@ class MicroBatchExecution(
SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) {
sink match {
case s: Sink => s.addBatch(currentBatchId, nextBatch)
- case _: StreamingWriteSupportProvider =>
+ case _: SupportsWrite =>
// This doesn't accumulate any data - it just forces execution of the microbatch writer.
nextBatch.collect()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
deleted file mode 100644
index 43ad4b3384ec3..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming;
-
-/**
- * This is an internal, deprecated interface. New source implementations should use the
- * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be
- * supported in the long term.
- *
- * This class will be removed in a future release.
- */
-public abstract class Offset {
- /**
- * A JSON-serialized representation of an Offset that is
- * used for saving offsets to the offset log.
- * Note: We assume that equivalent/equal offsets serialize to
- * identical JSON strings.
- *
- * @return JSON string encoding
- */
- public abstract String json();
-
- /**
- * Equality based on JSON string representation. We leverage the
- * JSON representation for normalization between the Offset's
- * in memory and on disk representations.
- */
- @Override
- public boolean equals(Object obj) {
- if (obj instanceof Offset) {
- return this.json().equals(((Offset) obj).json());
- } else {
- return false;
- }
- }
-
- @Override
- public int hashCode() {
- return this.json().hashCode();
- }
-
- @Override
- public String toString() {
- return this.json();
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
index 73cf355dbe758..b6fa2e9dc3612 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala
@@ -24,13 +24,15 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager}
import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _}
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream}
+
/**
* An ordered collection of offsets, used to track the progress of processing data from one or more
* [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance
* vector clock that must progress linearly forward.
*/
-case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMetadata] = None) {
+case class OffsetSeq(offsets: Seq[Option[OffsetV2]], metadata: Option[OffsetSeqMetadata] = None) {
/**
* Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of
@@ -39,7 +41,7 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet
* This method is typically used to associate a serialized offset with actual sources (which
* cannot be serialized).
*/
- def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = {
+ def toStreamProgress(sources: Seq[SparkDataStream]): StreamProgress = {
assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " +
s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " +
s"Cannot continue.")
@@ -56,13 +58,13 @@ object OffsetSeq {
* Returns a [[OffsetSeq]] with a variable sequence of offsets.
* `nulls` in the sequence are converted to `None`s.
*/
- def fill(offsets: Offset*): OffsetSeq = OffsetSeq.fill(None, offsets: _*)
+ def fill(offsets: OffsetV2*): OffsetSeq = OffsetSeq.fill(None, offsets: _*)
/**
* Returns a [[OffsetSeq]] with metadata and a variable sequence of offsets.
* `nulls` in the sequence are converted to `None`s.
*/
- def fill(metadata: Option[String], offsets: Offset*): OffsetSeq = {
+ def fill(metadata: Option[String], offsets: OffsetV2*): OffsetSeq = {
OffsetSeq(offsets.map(Option(_)), metadata.map(OffsetSeqMetadata.apply))
}
}
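With the signature widened to the v2 Offset type, existing call sites keep working. A small hypothetical example using LongOffset, assuming (as the rest of this patch arranges) that the v1 Offset type now extends the v2 one:

  import org.apache.spark.sql.execution.streaming.{LongOffset, OffsetSeq}

  // nulls in the varargs become None, matching the scaladoc above.
  val seq = OffsetSeq.fill(LongOffset(10), null, LongOffset(20))
  // seq.offsets == Seq(Some(LongOffset(10)), None, Some(LongOffset(20)))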
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala
index 2c8d7c7b0f3c5..8a05dade092c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala
@@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets._
import scala.io.{Source => IOSource}
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
/**
* This class is used to log offsets to persistent files in HDFS.
@@ -47,7 +48,7 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String)
override protected def deserialize(in: InputStream): OffsetSeq = {
// called inside a try-finally where the underlying stream is closed in the caller
- def parseOffset(value: String): Offset = value match {
+ def parseOffset(value: String): OffsetV2 = value match {
case OffsetSeqLog.SERIALIZED_VOID_OFFSET => null
case json => SerializedOffset(json)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 25283515b882f..932daef8965d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2Relation, StreamWriterCommitProgress}
-import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream
+import org.apache.spark.sql.sources.v2.Table
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, SparkDataStream}
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
import org.apache.spark.util.Clock
@@ -44,7 +45,7 @@ import org.apache.spark.util.Clock
trait ProgressReporter extends Logging {
case class ExecutionStats(
- inputRows: Map[BaseStreamingSource, Long],
+ inputRows: Map[SparkDataStream, Long],
stateOperators: Seq[StateOperatorProgress],
eventTimeStats: Map[String, String])
@@ -55,10 +56,10 @@ trait ProgressReporter extends Logging {
protected def triggerClock: Clock
protected def logicalPlan: LogicalPlan
protected def lastExecution: QueryExecution
- protected def newData: Map[BaseStreamingSource, LogicalPlan]
+ protected def newData: Map[SparkDataStream, LogicalPlan]
protected def sinkCommitProgress: Option[StreamWriterCommitProgress]
- protected def sources: Seq[BaseStreamingSource]
- protected def sink: BaseStreamingSink
+ protected def sources: Seq[SparkDataStream]
+ protected def sink: Table
protected def offsetSeqMetadata: OffsetSeqMetadata
protected def currentBatchId: Long
protected def sparkSession: SparkSession
@@ -67,8 +68,8 @@ trait ProgressReporter extends Logging {
// Local timestamps and counters.
private var currentTriggerStartTimestamp = -1L
private var currentTriggerEndTimestamp = -1L
- private var currentTriggerStartOffsets: Map[BaseStreamingSource, String] = _
- private var currentTriggerEndOffsets: Map[BaseStreamingSource, String] = _
+ private var currentTriggerStartOffsets: Map[SparkDataStream, String] = _
+ private var currentTriggerEndOffsets: Map[SparkDataStream, String] = _
// TODO: Restore this from the checkpoint when possible.
private var lastTriggerStartTimestamp = -1L
@@ -240,9 +241,9 @@ trait ProgressReporter extends Logging {
}
/** Extract number of input sources for each streaming source in plan */
- private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = {
+ private def extractSourceToNumInputRows(): Map[SparkDataStream, Long] = {
- def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = {
+ def sumRows(tuples: Seq[(SparkDataStream, Long)]): Map[SparkDataStream, Long] = {
tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source
}
@@ -262,7 +263,7 @@ trait ProgressReporter extends Logging {
val sourceToInputRowsTuples = lastExecution.executedPlan.collect {
case s: MicroBatchScanExec =>
val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L)
- val source = s.stream.asInstanceOf[BaseStreamingSource]
+ val source = s.stream
source -> numRows
}
logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
index 34bc085d920c1..190325fb7ec25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala
@@ -17,14 +17,21 @@
package org.apache.spark.sql.execution.streaming
+import java.util
+
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.sources.v2.{Table, TableCapability}
+import org.apache.spark.sql.types.StructType
/**
* An interface for systems that can collect the results of a streaming query. In order to preserve
* exactly once semantics a sink must be idempotent in the face of multiple attempts to add the same
* batch.
+ *
+ * Note that we extend `Table` here to keep the v1 streaming sink API compatible with
+ * data source v2.
*/
-trait Sink extends BaseStreamingSink {
+trait Sink extends Table {
/**
* Adds a batch of data to this sink. The data for a given `batchId` is deterministic and if
@@ -38,4 +45,16 @@ trait Sink extends BaseStreamingSink {
* after data is consumed by sink successfully.
*/
def addBatch(batchId: Long, data: DataFrame): Unit
+
+ override def name: String = {
+ throw new IllegalStateException("should not be called.")
+ }
+
+ override def schema: StructType = {
+ throw new IllegalStateException("should not be called.")
+ }
+
+ override def capabilities: util.Set[TableCapability] = {
+ throw new IllegalStateException("should not be called.")
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
index dbbd59e06909c..7f66d0b055cc3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala
@@ -18,14 +18,19 @@
package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
+import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream
import org.apache.spark.sql.types.StructType
/**
* A source of continually arriving data for a streaming query. A [[Source]] must have a
* monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark
* will regularly query each [[Source]] to see if any more data is available.
+ *
+ * Note that we extend `SparkDataStream` here to keep the v1 streaming source API compatible
+ * with data source v2.
*/
-trait Source extends BaseStreamingSource {
+trait Source extends SparkDataStream {
/** Returns the schema of the data from this source */
def schema: StructType
@@ -62,6 +67,15 @@ trait Source extends BaseStreamingSource {
*/
def commit(end: Offset) : Unit = {}
- /** Stop this source and free any resources it has allocated. */
- def stop(): Unit
+ override def initialOffset(): OffsetV2 = {
+ throw new IllegalStateException("should not be called.")
+ }
+
+ override def deserializeOffset(json: String): OffsetV2 = {
+ throw new IllegalStateException("should not be called.")
+ }
+
+ override def commit(end: OffsetV2): Unit = {
+ throw new IllegalStateException("should not be called.")
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 90f7b477103ae..4c08b3aa78666 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -24,6 +24,7 @@ import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit}
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.locks.{Condition, ReentrantLock}
+import scala.collection.JavaConverters._
import scala.collection.mutable.{Map => MutableMap}
import scala.util.control.NonFatal
@@ -34,11 +35,17 @@ import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.StreamingExplainCommand
import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.{SupportsWrite, Table}
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream}
+import org.apache.spark.sql.sources.v2.writer.SupportsTruncate
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
import org.apache.spark.sql.streaming._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.{Clock, UninterruptibleThread, Utils}
/** States for [[StreamExecution]]'s lifecycle. */
@@ -55,14 +62,15 @@ case object RECONFIGURING extends State
* and the results are committed transactionally to the given [[Sink]].
*
* @param deleteCheckpointOnStop whether to delete the checkpoint if the query is stopped without
- * errors
+ * errors. Checkpoint deletion can be forced with the appropriate
+ * Spark configuration.
*/
abstract class StreamExecution(
override val sparkSession: SparkSession,
override val name: String,
private val checkpointRoot: String,
analyzedPlan: LogicalPlan,
- val sink: BaseStreamingSink,
+ val sink: Table,
val trigger: Trigger,
val triggerClock: Clock,
val outputMode: OutputMode,
@@ -89,9 +97,47 @@ abstract class StreamExecution(
val resolvedCheckpointRoot = {
val checkpointPath = new Path(checkpointRoot)
val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf())
- fs.mkdirs(checkpointPath)
- checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString
+ if (sparkSession.conf.get(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED)
+ && StreamExecution.containsSpecialCharsInPath(checkpointPath)) {
+ // In Spark 2.4 and earlier, the checkpoint path is escaped 3 times (3 `Path.toUri.toString`
+ // calls). If this legacy checkpoint path exists, we will throw an error to tell the user how
+ // to migrate.
+ val legacyCheckpointDir =
+ new Path(new Path(checkpointPath.toUri.toString).toUri.toString).toUri.toString
+ val legacyCheckpointDirExists =
+ try {
+ fs.exists(new Path(legacyCheckpointDir))
+ } catch {
+ case NonFatal(e) =>
+ // We may not have access to this directory. Don't fail the query if that happens.
+ logWarning(e.getMessage, e)
+ false
+ }
+ if (legacyCheckpointDirExists) {
+ throw new SparkException(
+ s"""Error: we detected a possible problem with the location of your checkpoint and you
+ |likely need to move it before restarting this query.
+ |
+ |Earlier versions of Spark incorrectly escaped paths when writing out checkpoints for
+ |structured streaming. While this was corrected in Spark 3.0, it appears that your
+ |query was started using an earlier version that incorrectly handled the checkpoint
+ |path.
+ |
+ |Correct Checkpoint Directory: $checkpointPath
+ |Incorrect Checkpoint Directory: $legacyCheckpointDir
+ |
+ |Please move the data from the incorrect directory to the correct one, delete the
+ |incorrect directory, and then restart this query. If you believe you are receiving
+ |this message in error, you can disable it with the SQL conf
+ |${SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key}."""
+ .stripMargin)
+ }
+ }
+ val checkpointDir = checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
+ fs.mkdirs(checkpointDir)
+ checkpointDir.toString
}
+ logInfo(s"Checkpoint root $checkpointRoot resolved to $resolvedCheckpointRoot.")
def logicalPlan: LogicalPlan
@@ -160,7 +206,7 @@ abstract class StreamExecution(
/**
* A list of unique sources in the query plan. This will be set when generating logical plan.
*/
- @volatile protected var uniqueSources: Seq[BaseStreamingSource] = Seq.empty
+ @volatile protected var uniqueSources: Seq[SparkDataStream] = Seq.empty
/** Defines the internal state of execution */
protected val state = new AtomicReference[State](INITIALIZING)
@@ -169,7 +215,7 @@ abstract class StreamExecution(
var lastExecution: IncrementalExecution = _
/** Holds the most recent input data for each source. */
- protected var newData: Map[BaseStreamingSource, LogicalPlan] = _
+ protected var newData: Map[SparkDataStream, LogicalPlan] = _
@volatile
protected var streamDeathCause: StreamingQueryException = null
@@ -225,7 +271,7 @@ abstract class StreamExecution(
/** Returns the path of a file with `name` in the checkpoint directory. */
protected def checkpointFile(name: String): String =
- new Path(new Path(resolvedCheckpointRoot), name).toUri.toString
+ new Path(new Path(resolvedCheckpointRoot), name).toString
/**
* Starts the execution. This returns only after the thread has started and [[QueryStartedEvent]]
@@ -335,10 +381,13 @@ abstract class StreamExecution(
postEvent(
new QueryTerminatedEvent(id, runId, exception.map(_.cause).map(Utils.exceptionString)))
- // Delete the temp checkpoint only when the query didn't fail
- if (deleteCheckpointOnStop && exception.isEmpty) {
+ // Delete the temp checkpoint when force delete is enabled or the query didn't fail
+ if (deleteCheckpointOnStop &&
+ (sparkSession.sessionState.conf
+ .getConf(SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION) || exception.isEmpty)) {
val checkpointPath = new Path(resolvedCheckpointRoot)
try {
+ logInfo(s"Deleting checkpoint $checkpointPath.")
val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf())
fs.delete(checkpointPath, true)
} catch {
@@ -389,7 +438,7 @@ abstract class StreamExecution(
* Blocks the current thread until processing for data from the given `source` has reached at
* least the given `Offset`. This method is intended for use primarily when writing tests.
*/
- private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset, timeoutMs: Long): Unit = {
+ private[sql] def awaitOffset(sourceIndex: Int, newOffset: OffsetV2, timeoutMs: Long): Unit = {
assertAwaitThread()
def notDone = {
val localCommittedOffsets = committedOffsets
@@ -532,6 +581,35 @@ abstract class StreamExecution(
Option(name).map(_ + "
").getOrElse("") +
s"id = $id
runId = $runId
batch = $batchDescription"
}
+
+ protected def createStreamingWrite(
+ table: SupportsWrite,
+ options: Map[String, String],
+ inputPlan: LogicalPlan): StreamingWrite = {
+ val writeBuilder = table.newWriteBuilder(new CaseInsensitiveStringMap(options.asJava))
+ .withQueryId(id.toString)
+ .withInputDataSchema(inputPlan.schema)
+ outputMode match {
+ case Append =>
+ writeBuilder.buildForStreaming()
+
+ case Complete =>
+ // TODO: we should do this check earlier, once we have the capability API.
+ require(writeBuilder.isInstanceOf[SupportsTruncate],
+ table.name + " does not support Complete mode.")
+ writeBuilder.asInstanceOf[SupportsTruncate].truncate().buildForStreaming()
+
+ case Update =>
+ // Although no v2 sinks really support Update mode yet, during tests we do want them to
+ // pretend to support it, so Update mode is treated the same as Append mode.
+ if (Utils.isTesting) {
+ writeBuilder.buildForStreaming()
+ } else {
+ throw new IllegalArgumentException(
+ "Data source v2 streaming sinks do not support Update mode.")
+ }
+ }
+ }
}
object StreamExecution {
@@ -568,6 +646,11 @@ object StreamExecution {
case _ =>
false
}
+
+ /** Whether the path contains special chars that will be escaped when converting to a `URI`. */
+ def containsSpecialCharsInPath(path: Path): Boolean = {
+ path.toUri.getPath != new Path(path.toUri.toString).toUri.getPath
+ }
}
/**
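To make the escaped-path check above concrete, here is a small standalone sketch. It uses Hadoop's Path directly, and the sample directory name is made up; it shows how a path with URI-special characters is detected and how the legacy, repeatedly escaped directory is derived.

import org.apache.hadoop.fs.Path

// A checkpoint location containing a space, which Path.toUri.toString percent-escapes.
val checkpointPath = new Path("/tmp/my checkpoint")

// Same predicate as StreamExecution.containsSpecialCharsInPath: true when a round trip
// through the escaped string no longer yields the original path.
val hasSpecialChars =
  checkpointPath.toUri.getPath != new Path(checkpointPath.toUri.toString).toUri.getPath

// The legacy (Spark 2.4 and earlier) location applies Path.toUri.toString three times in total,
// so each round trip escapes the string again and the directory name drifts from the correct one.
val legacyCheckpointDir =
  new Path(new Path(checkpointPath.toUri.toString).toUri.toString).toUri.toString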
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
index 8531070b1bc49..8783eaa0e68b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala
@@ -19,32 +19,35 @@ package org.apache.spark.sql.execution.streaming
import scala.collection.{immutable, GenTraversableOnce}
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream}
+
+
/**
* A helper class that looks like a Map[Source, Offset].
*/
class StreamProgress(
- val baseMap: immutable.Map[BaseStreamingSource, Offset] =
- new immutable.HashMap[BaseStreamingSource, Offset])
- extends scala.collection.immutable.Map[BaseStreamingSource, Offset] {
+ val baseMap: immutable.Map[SparkDataStream, OffsetV2] =
+ new immutable.HashMap[SparkDataStream, OffsetV2])
+ extends scala.collection.immutable.Map[SparkDataStream, OffsetV2] {
- def toOffsetSeq(source: Seq[BaseStreamingSource], metadata: OffsetSeqMetadata): OffsetSeq = {
+ def toOffsetSeq(source: Seq[SparkDataStream], metadata: OffsetSeqMetadata): OffsetSeq = {
OffsetSeq(source.map(get), Some(metadata))
}
override def toString: String =
baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}")
- override def +[B1 >: Offset](kv: (BaseStreamingSource, B1)): Map[BaseStreamingSource, B1] = {
+ override def +[B1 >: OffsetV2](kv: (SparkDataStream, B1)): Map[SparkDataStream, B1] = {
baseMap + kv
}
- override def get(key: BaseStreamingSource): Option[Offset] = baseMap.get(key)
+ override def get(key: SparkDataStream): Option[OffsetV2] = baseMap.get(key)
- override def iterator: Iterator[(BaseStreamingSource, Offset)] = baseMap.iterator
+ override def iterator: Iterator[(SparkDataStream, OffsetV2)] = baseMap.iterator
- override def -(key: BaseStreamingSource): Map[BaseStreamingSource, Offset] = baseMap - key
+ override def -(key: SparkDataStream): Map[SparkDataStream, OffsetV2] = baseMap - key
- def ++(updates: GenTraversableOnce[(BaseStreamingSource, Offset)]): StreamProgress = {
+ def ++(updates: GenTraversableOnce[(SparkDataStream, OffsetV2)]): StreamProgress = {
new StreamProgress(baseMap ++ updates)
}
}
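A quick sketch of the map-like behaviour after this change. MemoryStream (updated later in this patch) stands in for an arbitrary SparkDataStream; the offsets and metadata values are illustrative only.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.{LongOffset, MemoryStream, OffsetSeqMetadata, StreamProgress}

val spark = SparkSession.builder().master("local").getOrCreate()
implicit val sqlContext: org.apache.spark.sql.SQLContext = spark.sqlContext
import spark.implicits._

val s1 = MemoryStream[Int]
val s2 = MemoryStream[String]

// StreamProgress behaves like an immutable Map keyed by SparkDataStream.
val progress = new StreamProgress() ++ Seq(s1 -> LongOffset(3), s2 -> LongOffset(1))
progress.get(s1)                                        // Some(LongOffset(3))
progress.toOffsetSeq(Seq(s1, s2), OffsetSeqMetadata())  // offsets in the given source order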
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index 83d38dcade7e6..142b6e7d18068 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -25,7 +25,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.execution.LeafExecNode
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.sources.v2.{DataSourceV2, Table}
+import org.apache.spark.sql.sources.v2.{Table, TableProvider}
+import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
object StreamingRelation {
def apply(dataSource: DataSource): StreamingRelation = {
@@ -62,7 +64,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output:
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
*/
case class StreamingExecutionRelation(
- source: BaseStreamingSource,
+ source: SparkDataStream,
output: Seq[Attribute])(session: SparkSession)
extends LeafNode with MultiInstanceRelation {
@@ -86,16 +88,16 @@ case class StreamingExecutionRelation(
// know at read time whether the query is continuous or not, so we need to be able to
// swap a V1 relation back in.
/**
- * Used to link a [[DataSourceV2]] into a streaming
+ * Used to link a [[TableProvider]] into a streaming
* [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating
* a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]],
* and should be converted before passing to [[StreamExecution]].
*/
case class StreamingRelationV2(
- dataSource: DataSourceV2,
+ source: TableProvider,
sourceName: String,
table: Table,
- extraOptions: Map[String, String],
+ extraOptions: CaseInsensitiveStringMap,
output: Seq[Attribute],
v1Relation: Option[StreamingRelation])(session: SparkSession)
extends LeafNode with MultiInstanceRelation {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
index 9c5c16f4f5d13..9ae39c79c5156 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala
@@ -17,30 +17,30 @@
package org.apache.spark.sql.execution.streaming
+import java.util
+
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql._
-import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport
+import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.writer.{SupportsTruncate, WriteBuilder}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame)
extends BaseRelation {
override def schema: StructType = data.schema
}
-class ConsoleSinkProvider extends DataSourceV2
- with StreamingWriteSupportProvider
+class ConsoleSinkProvider extends TableProvider
with DataSourceRegister
with CreatableRelationProvider {
- override def createStreamingWriteSupport(
- queryId: String,
- schema: StructType,
- mode: OutputMode,
- options: DataSourceOptions): StreamingWriteSupport = {
- new ConsoleWriteSupport(schema, options)
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ ConsoleTable
}
def createRelation(
@@ -60,3 +60,33 @@ class ConsoleSinkProvider extends DataSourceV2
def shortName(): String = "console"
}
+
+object ConsoleTable extends Table with SupportsWrite {
+
+ override def name(): String = "console"
+
+ override def schema(): StructType = StructType(Nil)
+
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(TableCapability.STREAMING_WRITE).asJava
+ }
+
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
+ new WriteBuilder with SupportsTruncate {
+ private var inputSchema: StructType = _
+
+ override def withInputDataSchema(schema: StructType): WriteBuilder = {
+ this.inputSchema = schema
+ this
+ }
+
+ // Do nothing for truncate. The console sink is special in that it just prints all the records.
+ override def truncate(): WriteBuilder = this
+
+ override def buildForStreaming(): StreamingWrite = {
+ assert(inputSchema != null)
+ new ConsoleWrite(inputSchema, options)
+ }
+ }
+ }
+}
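The builder flow above mirrors what createStreamingWrite in StreamExecution does. Below is a hedged sketch of that interaction; the query id, options, and schema are illustrative values, not taken from this patch.

import scala.collection.JavaConverters._
import org.apache.spark.sql.sources.v2.SupportsWrite
import org.apache.spark.sql.sources.v2.writer.SupportsTruncate
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val table = new ConsoleSinkProvider().getTable(CaseInsensitiveStringMap.empty()) // ConsoleTable
val builder = table.asInstanceOf[SupportsWrite]
  .newWriteBuilder(new CaseInsensitiveStringMap(Map("numRows" -> "5").asJava))
  .withQueryId(java.util.UUID.randomUUID().toString)
  .withInputDataSchema(new StructType().add("value", StringType))

// Append mode builds directly; Complete mode goes through SupportsTruncate.truncate() first,
// which for the console sink is a no-op.
val appendWrite: StreamingWrite = builder.buildForStreaming()
val completeWrite: StreamingWrite =
  builder.asInstanceOf[SupportsTruncate].truncate().buildForStreaming()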
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index b22795d207760..5475becc5bff4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.streaming.continuous
import java.util.UUID
import java.util.concurrent.TimeUnit
+import java.util.concurrent.atomic.AtomicReference
import java.util.function.UnaryOperator
-import scala.collection.JavaConverters._
import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.SparkEnv
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _}
import org.apache.spark.sql.sources.v2
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider, SupportsContinuousRead}
+import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, TableCapability}
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset}
import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
import org.apache.spark.util.Clock
@@ -42,7 +42,7 @@ class ContinuousExecution(
name: String,
checkpointRoot: String,
analyzedPlan: LogicalPlan,
- sink: StreamingWriteSupportProvider,
+ sink: SupportsWrite,
trigger: Trigger,
triggerClock: Clock,
outputMode: OutputMode,
@@ -57,26 +57,29 @@ class ContinuousExecution(
// For use only in test harnesses.
private[sql] var currentEpochCoordinatorId: String = _
- override val logicalPlan: LogicalPlan = {
+ // Throwable that caused the execution to fail
+ private val failure: AtomicReference[Throwable] = new AtomicReference[Throwable](null)
+
+ override val logicalPlan: WriteToContinuousDataSource = {
val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]()
var nextSourceId = 0
+ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
val _logicalPlan = analyzedPlan.transform {
- case s @ StreamingRelationV2(
- ds, dsName, table: SupportsContinuousRead, options, output, _) =>
+ case s @ StreamingRelationV2(ds, sourceName, table: SupportsRead, options, output, _) =>
+ if (!table.supports(TableCapability.CONTINUOUS_READ)) {
+ throw new UnsupportedOperationException(
+ s"Data source $sourceName does not support continuous processing.")
+ }
+
v2ToRelationMap.getOrElseUpdate(s, {
val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId"
nextSourceId += 1
- logInfo(s"Reading table [$table] from DataSourceV2 named '$dsName' [$ds]")
- val dsOptions = new DataSourceOptions(options.asJava)
+ logInfo(s"Reading table [$table] from DataSourceV2 named '$sourceName' [$ds]")
// TODO: operator pushdown.
- val scan = table.newScanBuilder(dsOptions).build()
+ val scan = table.newScanBuilder(options).build()
val stream = scan.toContinuousStream(metadataPath)
StreamingDataSourceV2Relation(output, scan, stream)
})
-
- case StreamingRelationV2(_, sourceName, _, _, _, _) =>
- throw new UnsupportedOperationException(
- s"Data source $sourceName does not support continuous processing.")
}
sources = _logicalPlan.collect {
@@ -84,7 +87,9 @@ class ContinuousExecution(
}
uniqueSources = sources.distinct
- _logicalPlan
+ // TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
+ WriteToContinuousDataSource(
+ createStreamingWrite(sink, extraOptions, _logicalPlan), _logicalPlan)
}
private val triggerExecutor = trigger match {
@@ -174,17 +179,10 @@ class ContinuousExecution(
"CurrentTimestamp and CurrentDate not yet supported for continuous processing")
}
- val writer = sink.createStreamingWriteSupport(
- s"$runId",
- withNewSources.schema,
- outputMode,
- new DataSourceOptions(extraOptions.asJava))
- val planWithSink = WriteToContinuousDataSource(writer, withNewSources)
-
reportTimeTaken("queryPlanning") {
lastExecution = new IncrementalExecution(
sparkSessionForQuery,
- planWithSink,
+ withNewSources,
outputMode,
checkpointFile("state"),
id,
@@ -194,7 +192,7 @@ class ContinuousExecution(
lastExecution.executedPlan // Force the lazy generation of execution plan
}
- val stream = planWithSink.collect {
+ val stream = withNewSources.collect {
case relation: StreamingDataSourceV2Relation =>
relation.stream.asInstanceOf[ContinuousStream]
}.head
@@ -214,9 +212,14 @@ class ContinuousExecution(
trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString)
// Use the parent Spark session for the endpoint since it's where this query ID is registered.
- val epochEndpoint =
- EpochCoordinatorRef.create(
- writer, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get)
+ val epochEndpoint = EpochCoordinatorRef.create(
+ logicalPlan.write,
+ stream,
+ this,
+ epochCoordinatorId,
+ currentBatchId,
+ sparkSession,
+ SparkEnv.get)
val epochUpdateThread = new Thread(new Runnable {
override def run: Unit = {
try {
@@ -258,19 +261,40 @@ class ContinuousExecution(
lastExecution.toRdd
}
}
+
+ val f = failure.get()
+ if (f != null) {
+ throw f
+ }
} catch {
case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) &&
state.get() == RECONFIGURING =>
logInfo(s"Query $id ignoring exception from reconfiguring: $t")
// interrupted by reconfiguration - swallow exception so we can restart the query
} finally {
- epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)
- SparkEnv.get.rpcEnv.stop(epochEndpoint)
-
- epochUpdateThread.interrupt()
- epochUpdateThread.join()
-
- sparkSession.sparkContext.cancelJobGroup(runId.toString)
+ // The above execution may finish before getting interrupted, for example, a Spark job with
+ // 0 partitions will complete immediately, and then the interrupted status can leak in here.
+ //
+ // To handle this case, we do two things here:
+ //
+ // 1. Clean up the resources in `queryExecutionThread.runUninterruptibly`. This may increase
+ // the waiting time of `stop` but should be minor because the operations here are very fast
+ // (just sending an RPC message in the same process and stopping a very simple thread).
+ // 2. Clear the interrupted status at the end so that it won't impact the `runContinuous`
+ // call. We may clear the interrupted status set by `stop`, but it doesn't affect the query
+ // termination because `runActivatedStream` will check `state` and exit accordingly.
+ queryExecutionThread.runUninterruptibly {
+ try {
+ epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)
+ } finally {
+ SparkEnv.get.rpcEnv.stop(epochEndpoint)
+ epochUpdateThread.interrupt()
+ epochUpdateThread.join()
+ // The following line must be the last line because it may fail if SparkContext is stopped
+ sparkSession.sparkContext.cancelJobGroup(runId.toString)
+ }
+ }
+ Thread.interrupted()
}
}
@@ -370,6 +394,35 @@ class ContinuousExecution(
}
}
+ /**
+ * Stores the error and stops the query execution thread, terminating the query in a new thread.
+ */
+ def stopInNewThread(error: Throwable): Unit = {
+ if (failure.compareAndSet(null, error)) {
+ logError(s"Query $prettyIdString received exception $error")
+ stopInNewThread()
+ }
+ }
+
+ /**
+ * Stops the query execution thread to terminate the query, using a new thread.
+ */
+ private def stopInNewThread(): Unit = {
+ new Thread("stop-continuous-execution") {
+ setDaemon(true)
+
+ override def run(): Unit = {
+ try {
+ ContinuousExecution.this.stop()
+ } catch {
+ case e: Throwable =>
+ logError(e.getMessage, e)
+ throw e
+ }
+ }
+ }.start()
+ }
+
/**
* Stops the query execution thread to terminate the query.
*/
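As a usage-level sketch (assuming an active SparkSession `spark`; the option values are illustrative), a continuous query that exercises this execution path is an ordinary structured streaming query with a continuous trigger:

import org.apache.spark.sql.streaming.Trigger

val query = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "5")
  .load()
  .writeStream
  .format("console")
  .trigger(Trigger.Continuous("1 second"))  // routes the query through ContinuousExecution
  .start()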
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
index 48ff70f9c9d07..d55f71c7be830 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala
@@ -23,17 +23,13 @@ import org.json4s.jackson.Serialization
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair}
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming._
case class RateStreamPartitionOffset(
partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset
-class RateStreamContinuousStream(
- rowsPerSecond: Long,
- numPartitions: Int,
- options: DataSourceOptions) extends ContinuousStream {
+class RateStreamContinuousStream(rowsPerSecond: Long, numPartitions: Int) extends ContinuousStream {
implicit val defaultFormats: DefaultFormats = DefaultFormats
val creationTime = System.currentTimeMillis()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
index e7bc71394061e..2263b42870a65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala
@@ -34,9 +34,9 @@ import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.{Offset => _, _}
import org.apache.spark.sql.execution.streaming.sources.TextSocketReader
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.RpcUtils
@@ -49,7 +49,7 @@ import org.apache.spark.util.RpcUtils
* buckets and serves the messages to the executors via an RPC endpoint.
*/
class TextSocketContinuousStream(
- host: String, port: Int, numPartitions: Int, options: DataSourceOptions)
+ host: String, port: Int, numPartitions: Int, options: CaseInsensitiveStringMap)
extends ContinuousStream with Logging {
implicit val defaultFormats: DefaultFormats = DefaultFormats
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index d1bda79f4b6ef..decf524f7167c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeR
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
import org.apache.spark.util.RpcUtils
private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
@@ -82,7 +82,7 @@ private[sql] object EpochCoordinatorRef extends Logging {
* Create a reference to a new [[EpochCoordinator]].
*/
def create(
- writeSupport: StreamingWriteSupport,
+ writeSupport: StreamingWrite,
stream: ContinuousStream,
query: ContinuousExecution,
epochCoordinatorId: String,
@@ -115,7 +115,7 @@ private[sql] object EpochCoordinatorRef extends Logging {
* have both committed and reported an end offset for a given epoch.
*/
private[continuous] class EpochCoordinator(
- writeSupport: StreamingWriteSupport,
+ writeSupport: StreamingWrite,
stream: ContinuousStream,
query: ContinuousExecution,
startEpoch: Long,
@@ -123,6 +123,9 @@ private[continuous] class EpochCoordinator(
override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with Logging {
+ private val epochBacklogQueueSize =
+ session.sqlContext.conf.continuousStreamingEpochBacklogQueueSize
+
private var queryWritesStopped: Boolean = false
private var numReaderPartitions: Int = _
@@ -212,6 +215,7 @@ private[continuous] class EpochCoordinator(
if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
partitionCommits.put((epoch, partitionId), message)
resolveCommitsAtEpoch(epoch)
+ checkProcessingQueueBoundaries()
}
case ReportPartitionOffset(partitionId, epoch, offset) =>
@@ -223,6 +227,22 @@ private[continuous] class EpochCoordinator(
query.addOffset(epoch, stream, thisEpochOffsets.toSeq)
resolveCommitsAtEpoch(epoch)
}
+ checkProcessingQueueBoundaries()
+ }
+
+ private def checkProcessingQueueBoundaries() = {
+ if (partitionOffsets.size > epochBacklogQueueSize) {
+ query.stopInNewThread(new IllegalStateException("Size of the partition offset queue has " +
+ "exceeded its maximum"))
+ }
+ if (partitionCommits.size > epochBacklogQueueSize) {
+ query.stopInNewThread(new IllegalStateException("Size of the partition commit queue has " +
+ "exceeded its maximum"))
+ }
+ if (epochsWaitingToBeCommitted.size > epochBacklogQueueSize) {
+ query.stopInNewThread(new IllegalStateException("Size of the epoch queue has " +
+ "exceeded its maximum"))
+ }
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
index 7ad21cc304e7c..54f484c4adae3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
/**
* The logical plan for writing data in a continuous stream.
*/
-case class WriteToContinuousDataSource(
- writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan {
+case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan)
+ extends LogicalPlan {
override def children: Seq[LogicalPlan] = Seq(query)
override def output: Seq[Attribute] = Nil
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index 2178466d63142..2f3af6a6544c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -26,21 +26,22 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
/**
- * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]].
+ * The physical plan for writing data into a continuous processing [[StreamingWrite]].
*/
-case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan)
- extends UnaryExecNode with Logging {
+case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan)
+ extends UnaryExecNode with Logging {
+
override def child: SparkPlan = query
override def output: Seq[Attribute] = Nil
override protected def doExecute(): RDD[InternalRow] = {
- val writerFactory = writeSupport.createStreamingWriterFactory()
+ val writerFactory = write.createStreamingWriterFactory()
val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
- logInfo(s"Start processing data source write support: $writeSupport. " +
+ logInfo(s"Start processing data source write support: $write. " +
s"The input RDD has ${rdd.partitions.length} partitions.")
EpochCoordinatorRef.get(
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index e71f81caeb974..df149552dfb30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -17,27 +17,26 @@
package org.apache.spark.sql.execution.streaming
+import java.util
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
-import scala.util.control.NonFatal
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.encoderFor
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
-import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2}
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2, SparkDataStream}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
object MemoryStream {
protected val currentBlockId = new AtomicInteger(0)
@@ -50,7 +49,7 @@ object MemoryStream {
/**
* A base class for memory stream implementations. Supports adding data and resetting.
*/
-abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource {
+abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends SparkDataStream {
val encoder = encoderFor[A]
protected val attributes = encoder.schema.toAttributes
@@ -62,10 +61,12 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas
Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
}
- def addData(data: A*): Offset = {
+ def addData(data: A*): OffsetV2 = {
addData(data.toTraversable)
}
+ def addData(data: TraversableOnce[A]): OffsetV2
+
def fullSchema(): StructType = encoder.schema
protected val logicalPlan: LogicalPlan = {
@@ -73,30 +74,43 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas
MemoryStreamTableProvider,
"memory",
new MemoryStreamTable(this),
- Map.empty,
+ CaseInsensitiveStringMap.empty(),
attributes,
None)(sqlContext.sparkSession)
}
- def addData(data: TraversableOnce[A]): Offset
+ override def initialOffset(): OffsetV2 = {
+ throw new IllegalStateException("should not be called.")
+ }
+
+ override def deserializeOffset(json: String): OffsetV2 = {
+ throw new IllegalStateException("should not be called.")
+ }
+
+ override def commit(end: OffsetV2): Unit = {
+ throw new IllegalStateException("should not be called.")
+ }
}
// This class is used to indicate the memory stream data source. We don't actually use it, as
// memory stream is for test only and we never look it up by name.
object MemoryStreamTableProvider extends TableProvider {
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
throw new IllegalStateException("MemoryStreamTableProvider should not be used.")
}
}
-class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table
- with SupportsMicroBatchRead with SupportsContinuousRead {
+class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table with SupportsRead {
override def name(): String = "MemoryStreamDataSource"
override def schema(): StructType = stream.fullSchema()
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava
+ }
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MemoryStreamScanBuilder(stream)
}
}
@@ -212,22 +226,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}
override def commit(end: OffsetV2): Unit = synchronized {
- def check(newOffset: LongOffset): Unit = {
- val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
-
- if (offsetDiff < 0) {
- sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
- }
+ val newOffset = end.asInstanceOf[LongOffset]
+ val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
- batches.trimStart(offsetDiff)
- lastOffsetCommitted = newOffset
+ if (offsetDiff < 0) {
+ sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end")
}
- LongOffset.convert(end) match {
- case Some(lo) => check(lo)
- case None => sys.error(s"MemoryStream.commit() received an offset ($end) " +
- "that did not originate with an instance of this class")
- }
+ batches.trimStart(offsetDiff)
+ lastOffsetCommitted = newOffset
}
override def stop() {}
@@ -262,93 +269,3 @@ object MemoryStreamReaderFactory extends PartitionReaderFactory {
}
}
}
-
-/** A common trait for MemorySinks with methods used for testing */
-trait MemorySinkBase extends BaseStreamingSink {
- def allData: Seq[Row]
- def latestBatchData: Seq[Row]
- def dataSinceBatch(sinceBatchId: Long): Seq[Row]
- def latestBatchId: Option[Long]
-}
-
-/**
- * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
- * tests and does not provide durability.
- */
-class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink
- with MemorySinkBase with Logging {
-
- private case class AddedData(batchId: Long, data: Array[Row])
-
- /** An order list of batches that have been written to this [[Sink]]. */
- @GuardedBy("this")
- private val batches = new ArrayBuffer[AddedData]()
-
- /** Returns all rows that are stored in this [[Sink]]. */
- def allData: Seq[Row] = synchronized {
- batches.flatMap(_.data)
- }
-
- def latestBatchId: Option[Long] = synchronized {
- batches.lastOption.map(_.batchId)
- }
-
- def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) }
-
- def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized {
- batches.filter(_.batchId > sinceBatchId).flatMap(_.data)
- }
-
- def toDebugString: String = synchronized {
- batches.map { case AddedData(batchId, data) =>
- val dataStr = try data.mkString(" ") catch {
- case NonFatal(e) => "[Error converting to string]"
- }
- s"$batchId: $dataStr"
- }.mkString("\n")
- }
-
- override def addBatch(batchId: Long, data: DataFrame): Unit = {
- val notCommitted = synchronized {
- latestBatchId.isEmpty || batchId > latestBatchId.get
- }
- if (notCommitted) {
- logDebug(s"Committing batch $batchId to $this")
- outputMode match {
- case Append | Update =>
- val rows = AddedData(batchId, data.collect())
- synchronized { batches += rows }
-
- case Complete =>
- val rows = AddedData(batchId, data.collect())
- synchronized {
- batches.clear()
- batches += rows
- }
-
- case _ =>
- throw new IllegalArgumentException(
- s"Output mode $outputMode is not supported by MemorySink")
- }
- } else {
- logDebug(s"Skipping already committed batch: $batchId")
- }
- }
-
- def clear(): Unit = synchronized {
- batches.clear()
- }
-
- override def toString(): String = "MemorySink"
-}
-
-/**
- * Used to query the data that has been written into a [[MemorySink]].
- */
-case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
- def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)
-
- private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes)
-
- override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
-}
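A hedged, test-style usage sketch of MemoryStream after this change (assuming an active SparkSession `spark`); it now reports both MICRO_BATCH_READ and CONTINUOUS_READ through the capability API rather than through marker interfaces.

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.MemoryStream

implicit val sqlContext: SQLContext = spark.sqlContext
import spark.implicits._

val input = MemoryStream[Int]
val query = input.toDS().writeStream.format("console").start()

input.addData(1, 2, 3)       // returns the OffsetV2 (a LongOffset) covering the new rows
query.processAllAvailable()
query.stop()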
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala
similarity index 92%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala
index 833e62f35ede1..dbe242784986d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala
@@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
-import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/** Common methods used to create writes for the console sink */
-class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions)
- extends StreamingWriteSupport with Logging {
+class ConsoleWrite(schema: StructType, options: CaseInsensitiveStringMap)
+ extends StreamingWrite with Logging {
// Number of rows to display, by default 20 rows
protected val numRowsToShow = options.getInt("numRows", 20)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
similarity index 60%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
index 4218fd51ad206..6da1b3a49c442 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala
@@ -17,68 +17,88 @@
package org.apache.spark.sql.execution.streaming.sources
+import java.util
+
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql.{ForeachWriter, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.python.PythonForeachWriter
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider}
-import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
-import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
- * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified
- * [[ForeachWriter]].
+ * A write-only table for forwarding data into the specified [[ForeachWriter]].
*
* @param writer The [[ForeachWriter]] to process all data.
* @param converter An object to convert internal rows to target type T. Either it can be
* a [[ExpressionEncoder]] or a direct converter function.
* @tparam T The expected type of the sink.
*/
-case class ForeachWriteSupportProvider[T](
+case class ForeachWriterTable[T](
writer: ForeachWriter[T],
converter: Either[ExpressionEncoder[T], InternalRow => T])
- extends StreamingWriteSupportProvider {
-
- override def createStreamingWriteSupport(
- queryId: String,
- schema: StructType,
- mode: OutputMode,
- options: DataSourceOptions): StreamingWriteSupport = {
- new StreamingWriteSupport {
- override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
- override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
-
- override def createStreamingWriterFactory(): StreamingDataWriterFactory = {
- val rowConverter: InternalRow => T = converter match {
- case Left(enc) =>
- val boundEnc = enc.resolveAndBind(
- schema.toAttributes,
- SparkSession.getActiveSession.get.sessionState.analyzer)
- boundEnc.fromRow
- case Right(func) =>
- func
- }
- ForeachWriterFactory(writer, rowConverter)
+ extends Table with SupportsWrite {
+
+ override def name(): String = "ForeachSink"
+
+ override def schema(): StructType = StructType(Nil)
+
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(TableCapability.STREAMING_WRITE).asJava
+ }
+
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
+ new WriteBuilder with SupportsTruncate {
+ private var inputSchema: StructType = _
+
+ override def withInputDataSchema(schema: StructType): WriteBuilder = {
+ this.inputSchema = schema
+ this
}
- override def toString: String = "ForeachSink"
+ // Do nothing for truncate. The foreach sink is special in that it just forwards all the
+ // records to the ForeachWriter.
+ override def truncate(): WriteBuilder = this
+
+ override def buildForStreaming(): StreamingWrite = {
+ new StreamingWrite {
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+ override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+ override def createStreamingWriterFactory(): StreamingDataWriterFactory = {
+ val rowConverter: InternalRow => T = converter match {
+ case Left(enc) =>
+ val boundEnc = enc.resolveAndBind(
+ inputSchema.toAttributes,
+ SparkSession.getActiveSession.get.sessionState.analyzer)
+ boundEnc.fromRow
+ case Right(func) =>
+ func
+ }
+ ForeachWriterFactory(writer, rowConverter)
+ }
+ }
+ }
}
}
}
-object ForeachWriteSupportProvider {
+object ForeachWriterTable {
def apply[T](
writer: ForeachWriter[T],
- encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = {
+ encoder: ExpressionEncoder[T]): ForeachWriterTable[_] = {
writer match {
case pythonWriter: PythonForeachWriter =>
- new ForeachWriteSupportProvider[UnsafeRow](
+ new ForeachWriterTable[UnsafeRow](
pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow]))
case _ =>
- new ForeachWriteSupportProvider[T](writer, Left(encoder))
+ new ForeachWriterTable[T](writer, Left(encoder))
}
}
}
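This table backs the DataStreamWriter.foreach(...) sink. A minimal usage sketch follows, assuming a streaming DataFrame `df`; the writer body is illustrative only.

import org.apache.spark.sql.{ForeachWriter, Row}

val query = df.writeStream
  .foreach(new ForeachWriter[Row] {
    override def open(partitionId: Long, epochId: Long): Boolean = true
    override def process(value: Row): Unit = println(value)   // side-effect per record
    override def close(errorOrNull: Throwable): Unit = {}
  })
  .start()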
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
index 143235efee81d..f3951897ea747 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala
@@ -19,14 +19,14 @@ package org.apache.spark.sql.execution.streaming.sources
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage}
-import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}
/**
* A [[BatchWrite]] used to hook V2 stream writers into a microbatch plan. It implements
* the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped
* streaming write support.
*/
-class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWriteSupport) extends BatchWrite {
+class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWrite) extends BatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = {
writeSupport.commit(eppchId, messages)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala
index a8feed34b96dc..5403eafd54b61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala
@@ -28,9 +28,9 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.{ManualClock, SystemClock}
class RateStreamMicroBatchStream(
@@ -38,7 +38,7 @@ class RateStreamMicroBatchStream(
// The default values here are used in tests.
rampUpTimeSeconds: Long = 0,
numPartitions: Int = 1,
- options: DataSourceOptions,
+ options: CaseInsensitiveStringMap,
checkpointLocation: String)
extends MicroBatchStream with Logging {
import RateStreamProvider._
@@ -155,7 +155,7 @@ class RateStreamMicroBatchStream(
override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " +
s"rampUpTimeSeconds=$rampUpTimeSeconds, " +
- s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}"
+ s"numPartitions=${options.getOrDefault(NUM_PARTITIONS, "default")}"
}
case class RateStreamMicroBatchInputPartition(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
index 075c6b9362ba2..8dbae9f787cf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala
@@ -17,6 +17,10 @@
package org.apache.spark.sql.execution.streaming.sources
+import java.util
+
+import scala.collection.JavaConverters._
+
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousStream
@@ -25,6 +29,7 @@ import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder}
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* A source that generates increasing long values with timestamps. Each generated row has two
@@ -40,18 +45,17 @@ import org.apache.spark.sql.types._
* generated rows. The source will try its best to reach `rowsPerSecond`, but the query may
* be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed.
*/
-class RateStreamProvider extends DataSourceV2
- with TableProvider with DataSourceRegister {
+class RateStreamProvider extends TableProvider with DataSourceRegister {
import RateStreamProvider._
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
val rowsPerSecond = options.getLong(ROWS_PER_SECOND, 1)
if (rowsPerSecond <= 0) {
throw new IllegalArgumentException(
s"Invalid value '$rowsPerSecond'. The option 'rowsPerSecond' must be positive")
}
- val rampUpTimeSeconds = Option(options.get(RAMP_UP_TIME).orElse(null))
+ val rampUpTimeSeconds = Option(options.get(RAMP_UP_TIME))
.map(JavaUtils.timeStringAsSec)
.getOrElse(0L)
if (rampUpTimeSeconds < 0) {
@@ -75,7 +79,7 @@ class RateStreamTable(
rowsPerSecond: Long,
rampUpTimeSeconds: Long,
numPartitions: Int)
- extends Table with SupportsMicroBatchRead with SupportsContinuousRead {
+ extends Table with SupportsRead {
override def name(): String = {
s"RateStream(rowsPerSecond=$rowsPerSecond, rampUpTimeSeconds=$rampUpTimeSeconds, " +
@@ -84,7 +88,11 @@ class RateStreamTable(
override def schema(): StructType = RateStreamProvider.SCHEMA
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder {
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava
+ }
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
override def build(): Scan = new Scan {
override def readSchema(): StructType = RateStreamProvider.SCHEMA
@@ -94,7 +102,7 @@ class RateStreamTable(
}
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
- new RateStreamContinuousStream(rowsPerSecond, numPartitions, options)
+ new RateStreamContinuousStream(rowsPerSecond, numPartitions)
}
}
}
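A usage sketch of the rate source with the options validated above (values illustrative, `spark` assumed to be an active SparkSession):

val rateDF = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "10")   // must be positive
  .option("rampUpTime", "5s")      // must not be negative
  .option("numPartitions", "2")
  .load()                          // schema: timestamp TIMESTAMP, value LONG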
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala
index 540131c8de8a1..dd8d89238008e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala
@@ -29,7 +29,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming.LongOffset
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset}
import org.apache.spark.unsafe.types.UTF8String
@@ -39,8 +38,7 @@ import org.apache.spark.unsafe.types.UTF8String
* and debugging. This MicroBatchReadSupport will *not* work in production applications for
* multiple reasons, including no support for fault recovery.
*/
-class TextSocketMicroBatchStream(
- host: String, port: Int, numPartitions: Int, options: DataSourceOptions)
+class TextSocketMicroBatchStream(host: String, port: Int, numPartitions: Int)
extends MicroBatchStream with Logging {
@GuardedBy("this")
@@ -155,10 +153,7 @@ class TextSocketMicroBatchStream(
}
override def commit(end: Offset): Unit = synchronized {
- val newOffset = LongOffset.convert(end).getOrElse(
- sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " +
- s"originate with an instance of this class")
- )
+ val newOffset = end.asInstanceOf[LongOffset]
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala
index c3b24a8f65dd9..e714859c16ddd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala
@@ -18,8 +18,10 @@
package org.apache.spark.sql.execution.streaming.sources
import java.text.SimpleDateFormat
+import java.util
import java.util.Locale
+import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}
import org.apache.spark.internal.Logging
@@ -30,21 +32,21 @@ import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder}
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
-class TextSocketSourceProvider extends DataSourceV2
- with TableProvider with DataSourceRegister with Logging {
+class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging {
- private def checkParameters(params: DataSourceOptions): Unit = {
+ private def checkParameters(params: CaseInsensitiveStringMap): Unit = {
logWarning("The socket source should not be used for production applications! " +
"It does not support recovery.")
- if (!params.get("host").isPresent) {
+ if (!params.containsKey("host")) {
throw new AnalysisException("Set a host to read from with option(\"host\", ...).")
}
- if (!params.get("port").isPresent) {
+ if (!params.containsKey("port")) {
throw new AnalysisException("Set a port to read from with option(\"port\", ...).")
}
Try {
- params.get("includeTimestamp").orElse("false").toBoolean
+ params.getBoolean("includeTimestamp", false)
} match {
case Success(_) =>
case Failure(_) =>
@@ -52,10 +54,10 @@ class TextSocketSourceProvider extends DataSourceV2
}
}
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
checkParameters(options)
new TextSocketTable(
- options.get("host").get,
+ options.get("host"),
options.getInt("port", -1),
options.getInt("numPartitions", SparkSession.active.sparkContext.defaultParallelism),
options.getBoolean("includeTimestamp", false))
@@ -66,7 +68,7 @@ class TextSocketSourceProvider extends DataSourceV2
}
class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimestamp: Boolean)
- extends Table with SupportsMicroBatchRead with SupportsContinuousRead {
+ extends Table with SupportsRead {
override def name(): String = s"Socket[$host:$port]"
@@ -78,12 +80,16 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest
}
}
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder {
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava
+ }
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
override def build(): Scan = new Scan {
override def readSchema(): StructType = schema()
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
- new TextSocketMicroBatchStream(host, port, numPartitions, options)
+ new TextSocketMicroBatchStream(host, port, numPartitions)
}
override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
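
The provider rewrite above swaps DataSourceOptions for CaseInsensitiveStringMap and replaces the SupportsMicroBatchRead/SupportsContinuousRead marker traits with an explicit capabilities() set on the table. Below is a rough, self-contained Scala sketch of that shape; the traits, the Capability enumeration, and the CIOptions map are simplified local stand-ins for the Spark interfaces, not the real classes.

// Local stand-ins for the v2 interfaces (assumption: heavily simplified).
object Capability extends Enumeration { val MicroBatchRead, ContinuousRead = Value }

trait SimpleTable {
  def name(): String
  def capabilities(): Set[Capability.Value]
}

// Case-insensitive options, mirroring the intent of CaseInsensitiveStringMap.
class CIOptions(entries: Map[String, String]) {
  private val map: Map[String, String] = entries.map { case (k, v) => k.toLowerCase -> v }
  def get(key: String): Option[String] = map.get(key.toLowerCase)
  def getInt(key: String, default: Int): Int = get(key).map(_.toInt).getOrElse(default)
  def containsKey(key: String): Boolean = map.contains(key.toLowerCase)
}

class SocketLikeTable(host: String, port: Int) extends SimpleTable {
  override def name(): String = s"Socket[$host:$port]"
  override def capabilities(): Set[Capability.Value] =
    Set(Capability.MicroBatchRead, Capability.ContinuousRead)
}

object ProviderSketch {
  // Mirrors getTable(options): validate options first, then build the table.
  def getTable(options: CIOptions): SimpleTable = {
    require(options.containsKey("host"), "Set a host to read from with option(\"host\", ...)")
    require(options.containsKey("port"), "Set a port to read from with option(\"port\", ...)")
    new SocketLikeTable(options.get("host").get, options.getInt("port", -1))
  }

  def main(args: Array[String]): Unit = {
    val table = getTable(new CIOptions(Map("HOST" -> "localhost", "Port" -> "9999")))
    println(table.name() + " supports " + table.capabilities())
  }
}
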
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
new file mode 100644
index 0000000000000..a3f58fa966fe8
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
+
+/**
+ * The logical plan for writing data to a micro-batch stream.
+ *
+ * Note that this logical plan does not have a corresponding physical plan, as it will be converted
+ * to [[WriteToDataSourceV2]] with [[MicroBatchWrite]] before execution.
+ */
+case class WriteToMicroBatchDataSource(write: StreamingWrite, query: LogicalPlan)
+ extends LogicalPlan {
+ override def children: Seq[LogicalPlan] = Seq(query)
+ override def output: Seq[Attribute] = Nil
+
+ def createPlan(batchId: Long): WriteToDataSourceV2 = {
+ WriteToDataSourceV2(new MicroBatchWrite(batchId, write), query)
+ }
+}
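
WriteToMicroBatchDataSource above is a write "template": it has no physical plan of its own and is re-specialized for every batch via createPlan(batchId). A small sketch of that idea with local stand-in types (StreamingWrite, MicroBatchWrite and the query/plan classes below are simplified assumptions, not the Spark ones):

// Local stand-ins (assumption: simplified shapes, not the actual Spark classes).
trait StreamingWrite { def commit(epochId: Long, rows: Seq[String]): Unit }

// Pins a streaming write to one micro-batch, in the spirit of MicroBatchWrite.
class MicroBatchWrite(batchId: Long, write: StreamingWrite) {
  def commit(rows: Seq[String]): Unit = write.commit(batchId, rows)
}

case class Query(rows: Seq[String])

// Template node: holds the write and the query, and is specialized per batch.
case class WriteToMicroBatchSketch(write: StreamingWrite, query: Query) {
  def createPlan(batchId: Long): MicroBatchWrite = new MicroBatchWrite(batchId, write)
}

object WriteToMicroBatchSketch {
  def main(args: Array[String]): Unit = {
    val write = new StreamingWrite {
      def commit(epochId: Long, rows: Seq[String]): Unit =
        println(s"epoch $epochId committed ${rows.size} rows")
    }
    val node = WriteToMicroBatchSketch(write, Query(Seq("a", "b")))
    Seq(0L, 1L).foreach(id => node.createPlan(id).commit(node.query.rows))
  }
}
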
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
similarity index 69%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
index c50dc7bcb8da1..de8d00d4ac348 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala
@@ -17,8 +17,10 @@
package org.apache.spark.sql.execution.streaming.sources
+import java.util
import javax.annotation.concurrent.GuardedBy
+import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
@@ -30,27 +32,46 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
-import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
-import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink}
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider}
+import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport}
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
*/
-class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider
- with MemorySinkBase with Logging {
+class MemorySink extends Table with SupportsWrite with Logging {
- override def createStreamingWriteSupport(
- queryId: String,
- schema: StructType,
- mode: OutputMode,
- options: DataSourceOptions): StreamingWriteSupport = {
- new MemoryStreamingWriteSupport(this, mode, schema)
+ override def name(): String = "MemorySink"
+
+ override def schema(): StructType = StructType(Nil)
+
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(TableCapability.STREAMING_WRITE).asJava
+ }
+
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
+ new WriteBuilder with SupportsTruncate {
+ private var needTruncate: Boolean = false
+ private var inputSchema: StructType = _
+
+ override def truncate(): WriteBuilder = {
+ this.needTruncate = true
+ this
+ }
+
+ override def withInputDataSchema(schema: StructType): WriteBuilder = {
+ this.inputSchema = schema
+ this
+ }
+
+ override def buildForStreaming(): StreamingWrite = {
+ new MemoryStreamingWrite(MemorySink.this, inputSchema, needTruncate)
+ }
+ }
}
private case class AddedData(batchId: Long, data: Array[Row])
@@ -85,27 +106,20 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider
}.mkString("\n")
}
- def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = {
+ def write(batchId: Long, needTruncate: Boolean, newRows: Array[Row]): Unit = {
val notCommitted = synchronized {
latestBatchId.isEmpty || batchId > latestBatchId.get
}
if (notCommitted) {
logDebug(s"Committing batch $batchId to $this")
- outputMode match {
- case Append | Update =>
- val rows = AddedData(batchId, newRows)
- synchronized { batches += rows }
-
- case Complete =>
- val rows = AddedData(batchId, newRows)
- synchronized {
- batches.clear()
- batches += rows
- }
-
- case _ =>
- throw new IllegalArgumentException(
- s"Output mode $outputMode is not supported by MemorySinkV2")
+ val rows = AddedData(batchId, newRows)
+ if (needTruncate) {
+ synchronized {
+ batches.clear()
+ batches += rows
+ }
+ } else {
+ synchronized { batches += rows }
}
} else {
logDebug(s"Skipping already committed batch: $batchId")
@@ -116,25 +130,25 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider
batches.clear()
}
- override def toString(): String = "MemorySinkV2"
+ override def toString(): String = "MemorySink"
}
case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row])
extends WriterCommitMessage {}
-class MemoryStreamingWriteSupport(
- val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType)
- extends StreamingWriteSupport {
+class MemoryStreamingWrite(
+ val sink: MemorySink, schema: StructType, needTruncate: Boolean)
+ extends StreamingWrite {
override def createStreamingWriterFactory: MemoryWriterFactory = {
- MemoryWriterFactory(outputMode, schema)
+ MemoryWriterFactory(schema)
}
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
val newRows = messages.flatMap {
case message: MemoryWriterCommitMessage => message.data
}
- sink.write(epochId, outputMode, newRows)
+ sink.write(epochId, needTruncate, newRows)
}
override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
@@ -142,13 +156,13 @@ class MemoryStreamingWriteSupport(
}
}
-case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType)
+case class MemoryWriterFactory(schema: StructType)
extends DataWriterFactory with StreamingDataWriterFactory {
override def createWriter(
partitionId: Int,
taskId: Long): DataWriter[InternalRow] = {
- new MemoryDataWriter(partitionId, outputMode, schema)
+ new MemoryDataWriter(partitionId, schema)
}
override def createWriter(
@@ -159,7 +173,7 @@ case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType)
}
}
-class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType)
+class MemoryDataWriter(partition: Int, schema: StructType)
extends DataWriter[InternalRow] with Logging {
private val data = mutable.Buffer[Row]()
@@ -181,9 +195,9 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructTyp
/**
- * Used to query the data that has been written into a [[MemorySinkV2]].
+ * Used to query the data that has been written into a [[MemorySink]].
*/
-case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
+case class MemoryPlan(sink: MemorySink, override val output: Seq[Attribute]) extends LeafNode {
private val sizePerRow = EstimationUtils.getSizePerRow(output)
override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size)
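
The sink rewrite above drops the OutputMode pattern match in favor of a single needTruncate flag chosen by the WriteBuilder (truncate() for Complete-style output), while keeping the "skip already committed batch" guard. A compact, self-contained sketch of that commit logic, using strings in place of rows; this is not the actual MemorySink class.

import scala.collection.mutable.ArrayBuffer

// Simplified in-memory sink mirroring the batchId / needTruncate commit logic.
class InMemorySinkSketch {
  private case class AddedData(batchId: Long, data: Seq[String])
  private val batches = new ArrayBuffer[AddedData]()

  def allData: Seq[String] = synchronized(batches.flatMap(_.data).toList)

  def write(batchId: Long, needTruncate: Boolean, newRows: Seq[String]): Unit = synchronized {
    val latestBatchId = batches.lastOption.map(_.batchId)
    if (latestBatchId.forall(batchId > _)) {        // skip re-delivered (already committed) batches
      if (needTruncate) batches.clear()             // truncate-then-append, i.e. Complete semantics
      batches += AddedData(batchId, newRows)
    }
  }
}

object InMemorySinkSketch {
  def main(args: Array[String]): Unit = {
    val sink = new InMemorySinkSketch
    sink.write(0, needTruncate = false, Seq("a"))
    sink.write(1, needTruncate = false, Seq("b"))
    sink.write(1, needTruncate = false, Seq("b"))   // ignored: batch 1 already committed
    println(sink.allData)                           // List(a, b)
    sink.write(2, needTruncate = true, Seq("c"))    // Complete-style rewrite of the sink contents
    println(sink.allData)                           // List(c)
  }
}
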
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index a605dc640dc96..18029abb08dab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck}
import org.apache.spark.sql.streaming.StreamingQueryManager
import org.apache.spark.sql.util.ExecutionListenerManager
@@ -160,6 +161,7 @@ abstract class BaseSessionStateBuilder(
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
new FallbackOrcDataSourceV2(session) +:
+ DataSourceResolution(conf, session.catalog(_)) +:
customResolutionRules
override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
@@ -172,6 +174,8 @@ abstract class BaseSessionStateBuilder(
PreWriteCheck +:
PreReadCheck +:
HiveOnlyCheck +:
+ V2WriteSupportCheck +:
+ V2StreamingScanSupportCheck +:
customCheckRules
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index a10bd2218eb38..da4723e34c0d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -30,7 +30,9 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* Interface used to load a streaming `Dataset` from external storage systems (e.g. file systems,
@@ -173,22 +175,24 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
ds match {
case provider: TableProvider =>
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
- ds = provider, conf = sparkSession.sessionState.conf)
+ source = provider, conf = sparkSession.sessionState.conf)
val options = sessionOptions ++ extraOptions
- val dsOptions = new DataSourceOptions(options.asJava)
+ val dsOptions = new CaseInsensitiveStringMap(options.asJava)
val table = userSpecifiedSchema match {
case Some(schema) => provider.getTable(dsOptions, schema)
case _ => provider.getTable(dsOptions)
}
+ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
table match {
- case _: SupportsMicroBatchRead | _: SupportsContinuousRead =>
+ case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) =>
Dataset.ofRows(
sparkSession,
StreamingRelationV2(
- provider, source, table, options, table.schema.toAttributes, v1Relation)(
+ provider, source, table, dsOptions, table.schema.toAttributes, v1Relation)(
sparkSession))
// fallback to v1
+ // TODO (SPARK-27483): we should move this fallback logic to an analyzer rule.
case _ => Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource))
}
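
The load() path above no longer pattern-matches on marker traits; it asks the returned table whether it supports MICRO_BATCH_READ or CONTINUOUS_READ (via the supportsAny helper from DataSourceV2Implicits) and only then builds a v2 streaming relation, otherwise it falls back to the v1 path. A small sketch of that dispatch with simplified, local stand-ins for the capability and table types:

object StreamReadDispatchSketch {
  object Capability extends Enumeration { val MicroBatchRead, ContinuousRead, BatchRead = Value }

  trait TableLike { def capabilities: Set[Capability.Value] }

  // In the spirit of supportsAny: true if any of the requested capabilities is declared.
  def supportsAny(table: TableLike, wanted: Capability.Value*): Boolean =
    wanted.exists(table.capabilities.contains)

  def chooseRelation(table: TableLike): String =
    if (supportsAny(table, Capability.MicroBatchRead, Capability.ContinuousRead)) {
      "v2 StreamingRelationV2"
    } else {
      "v1 StreamingRelation fallback"
    }

  def main(args: Array[String]): Unit = {
    val streaming = new TableLike { val capabilities = Set(Capability.MicroBatchRead) }
    val batchOnly = new TableLike { val capabilities = Set(Capability.BatchRead) }
    println(chooseRelation(streaming))  // v2 StreamingRelationV2
    println(chooseRelation(batchOnly))  // v1 StreamingRelation fallback
  }
}
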
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index ea596ba728c19..d051cf9c1d4a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -31,7 +31,9 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.sources._
-import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider
+import org.apache.spark.sql.sources.v2.{SupportsWrite, TableProvider}
+import org.apache.spark.sql.sources.v2.TableCapability._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -252,16 +254,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
if (extraOptions.get("queryName").isEmpty) {
throw new AnalysisException("queryName must be specified for memory sink")
}
- val (sink, resultDf) = trigger match {
- case _: ContinuousTrigger =>
- val s = new MemorySinkV2()
- val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
- (s, r)
- case _ =>
- val s = new MemorySink(df.schema, outputMode)
- val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
- (s, r)
- }
+ val sink = new MemorySink()
+ val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes))
val chkpointLoc = extraOptions.get("checkpointLocation")
val recoverFromChkpoint = outputMode == OutputMode.Complete()
val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
@@ -278,7 +272,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
query
} else if (source == "foreach") {
assertNotPartitioned("foreach")
- val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc)
+ val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc)
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
@@ -304,30 +298,31 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
useTempCheckpointLocation = true,
trigger = trigger)
} else {
- val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
+ val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
- var options = extraOptions.toMap
- val sink = ds.getConstructor().newInstance() match {
- case w: StreamingWriteSupportProvider
- if !disabledSources.contains(w.getClass.getCanonicalName) =>
- val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
- w, df.sparkSession.sessionState.conf)
- options = sessionOptions ++ extraOptions
- w
- case _ =>
- val ds = DataSource(
- df.sparkSession,
- className = source,
- options = options,
- partitionColumns = normalizedParCols.getOrElse(Nil))
- ds.createSink(outputMode)
+ val useV1Source = disabledSources.contains(cls.getCanonicalName)
+
+ val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) {
+ val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
+ val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
+ source = provider, conf = df.sparkSession.sessionState.conf)
+ val options = sessionOptions ++ extraOptions
+ val dsOptions = new CaseInsensitiveStringMap(options.asJava)
+ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
+ provider.getTable(dsOptions) match {
+ case table: SupportsWrite if table.supports(STREAMING_WRITE) =>
+ table
+ case _ => createV1Sink()
+ }
+ } else {
+ createV1Sink()
}
df.sparkSession.sessionState.streamingQueryManager.startQuery(
- options.get("queryName"),
- options.get("checkpointLocation"),
+ extraOptions.get("queryName"),
+ extraOptions.get("checkpointLocation"),
df,
- options,
+ extraOptions.toMap,
sink,
outputMode,
useTempCheckpointLocation = source == "console" || source == "noop",
@@ -336,6 +331,15 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
}
}
+ private def createV1Sink(): Sink = {
+ val ds = DataSource(
+ df.sparkSession,
+ className = source,
+ options = extraOptions.toMap,
+ partitionColumns = normalizedParCols.getOrElse(Nil))
+ ds.createSink(outputMode)
+ }
+
/**
* Sets the output of the streaming query to be processed using the provided writer object.
 * See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and
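
The start() rewrite above turns sink selection into a two-step check: the source class must implement TableProvider and not appear in the disabled-v2-writers list, and the table it returns must actually declare STREAMING_WRITE; everything else goes through createV1Sink(). A condensed sketch of that decision, with local stand-ins for the provider, table and capability types (none of these are the real Spark classes):

object StreamSinkSelectionSketch {
  sealed trait Sink
  case class V1Sink(source: String) extends Sink
  case class V2Table(name: String, streamingWrite: Boolean) extends Sink

  // Stand-in for a v2 provider; the real code loads the class via DataSource.lookupDataSource.
  trait Provider { def getTable(options: Map[String, String]): V2Table }

  def chooseSink(
      source: String,
      provider: Option[Provider],        // Some(_) if the class implements TableProvider
      disabledV2: Set[String],
      options: Map[String, String]): Sink = {
    provider match {
      case Some(p) if !disabledV2.contains(source) =>
        val table = p.getTable(options)
        if (table.streamingWrite) table else V1Sink(source)   // capability check
      case _ => V1Sink(source)                                // not a TableProvider, or disabled
    }
  }

  def main(args: Array[String]): Unit = {
    val p = new Provider {
      def getTable(options: Map[String, String]): V2Table = V2Table("demo", streamingWrite = true)
    }
    println(chooseSink("demo", Some(p), disabledV2 = Set.empty, options = Map.empty))    // V2Table(demo,true)
    println(chooseSink("demo", Some(p), disabledV2 = Set("demo"), options = Map.empty))  // V1Sink(demo)
  }
}
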
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 881cd96cc9dc9..63fb9ed176b9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution,
import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS
-import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider
+import org.apache.spark.sql.sources.v2.{SupportsWrite, Table}
import org.apache.spark.util.{Clock, SystemClock, Utils}
/**
@@ -206,7 +206,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
userSpecifiedCheckpointLocation: Option[String],
df: DataFrame,
extraOptions: Map[String, String],
- sink: BaseStreamingSink,
+ sink: Table,
outputMode: OutputMode,
useTempCheckpointLocation: Boolean,
recoverFromCheckpointLocation: Boolean,
@@ -214,16 +214,20 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
triggerClock: Clock): StreamingQueryWrapper = {
var deleteCheckpointOnStop = false
val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified =>
- new Path(userSpecified).toUri.toString
+ new Path(userSpecified).toString
}.orElse {
df.sparkSession.sessionState.conf.checkpointLocation.map { location =>
- new Path(location, userSpecifiedName.getOrElse(UUID.randomUUID().toString)).toUri.toString
+ new Path(location, userSpecifiedName.getOrElse(UUID.randomUUID().toString)).toString
}
}.getOrElse {
if (useTempCheckpointLocation) {
- // Delete the temp checkpoint when a query is being stopped without errors.
deleteCheckpointOnStop = true
- Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath
+ val tempDir = Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath
+ logWarning("Temporary checkpoint location created which is deleted normally when" +
+ s" the query didn't fail: $tempDir. If it's required to delete it under any" +
+ s" circumstances, please set ${SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.key} to" +
+ s" true. Important to know deleting temp checkpoint folder is best effort.")
+ tempDir
} else {
throw new AnalysisException(
"checkpointLocation must be specified either " +
@@ -254,7 +258,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
}
(sink, trigger) match {
- case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) =>
+ case (table: SupportsWrite, trigger: ContinuousTrigger) =>
if (operationCheckEnabled) {
UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
}
@@ -263,7 +267,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
userSpecifiedName.orNull,
checkpointLocation,
analyzedPlan,
- v2Sink,
+ table,
trigger,
triggerClock,
outputMode,
@@ -308,7 +312,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
userSpecifiedCheckpointLocation: Option[String],
df: DataFrame,
extraOptions: Map[String, String],
- sink: BaseStreamingSink,
+ sink: Table,
outputMode: OutputMode,
useTempCheckpointLocation: Boolean = false,
recoverFromCheckpointLocation: Boolean = true,
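
Two things change in startQuery above: the sink parameter becomes a plain v2 Table, and checkpoint paths keep Path.toString (rather than toUri.toString) so that paths containing special characters resolve consistently (see the escaped-path test resources added below), with a logged warning when a temporary checkpoint directory is created. The resolution order itself is unchanged; a small sketch of that order (explicit option, then session config, then a temp dir only when allowed), using a hypothetical local helper rather than the real StreamingQueryManager code:

import java.nio.file.Files
import java.util.UUID

object CheckpointResolutionSketch {
  // Returns (checkpoint path, deleteOnStop). Fallback order mirrors the code above.
  def resolveCheckpoint(
      userSpecified: Option[String],
      confLocation: Option[String],
      queryName: Option[String],
      useTempCheckpointLocation: Boolean): (String, Boolean) = {
    userSpecified.map(p => (p, false))
      .orElse(confLocation.map(loc => (s"$loc/${queryName.getOrElse(UUID.randomUUID().toString)}", false)))
      .getOrElse {
        require(useTempCheckpointLocation, "checkpointLocation must be specified")
        val tempDir = Files.createTempDirectory("temporary").toString
        println(s"WARN: temporary checkpoint location created: $tempDir (deleted on clean stop, best effort)")
        (tempDir, true)
      }
  }

  def main(args: Array[String]): Unit = {
    println(resolveCheckpoint(Some("/chk/my query"), None, None, useTempCheckpointLocation = false))
    println(resolveCheckpoint(None, Some("/spark/chk"), Some("q1"), useTempCheckpointLocation = false))
  }
}
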
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
index 2612b6185fd4c..255a9f887878b 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
@@ -24,19 +24,19 @@
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.Table;
import org.apache.spark.sql.sources.v2.TableProvider;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
public class JavaAdvancedDataSourceV2 implements TableProvider {
@Override
- public Table getTable(DataSourceOptions options) {
+ public Table getTable(CaseInsensitiveStringMap options) {
return new JavaSimpleBatchTable() {
@Override
- public ScanBuilder newScanBuilder(DataSourceOptions options) {
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new AdvancedScanBuilder();
}
};
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java
index d72ab5338aa8c..699859cfaebe1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java
@@ -21,11 +21,11 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.Table;
import org.apache.spark.sql.sources.v2.TableProvider;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
@@ -49,10 +49,10 @@ public PartitionReaderFactory createReaderFactory() {
}
@Override
- public Table getTable(DataSourceOptions options) {
+ public Table getTable(CaseInsensitiveStringMap options) {
return new JavaSimpleBatchTable() {
@Override
- public ScanBuilder newScanBuilder(DataSourceOptions options) {
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new MyScanBuilder();
}
};
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
index a513bfb26ef1c..391af5a306a16 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
@@ -20,15 +20,17 @@
import java.io.IOException;
import java.util.Arrays;
+import org.apache.spark.sql.catalog.v2.expressions.Expressions;
+import org.apache.spark.sql.catalog.v2.expressions.Transform;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.Table;
import org.apache.spark.sql.sources.v2.TableProvider;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution;
import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution;
import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
public class JavaPartitionAwareDataSource implements TableProvider {
@@ -54,10 +56,15 @@ public Partitioning outputPartitioning() {
}
@Override
- public Table getTable(DataSourceOptions options) {
+ public Table getTable(CaseInsensitiveStringMap options) {
return new JavaSimpleBatchTable() {
@Override
- public ScanBuilder newScanBuilder(DataSourceOptions options) {
+ public Transform[] partitioning() {
+ return new Transform[] { Expressions.identity("i") };
+ }
+
+ @Override
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new MyScanBuilder();
}
};
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java
new file mode 100644
index 0000000000000..f3755e18b58d5
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.util.OptionalLong;
+
+import org.apache.spark.sql.sources.v2.Table;
+import org.apache.spark.sql.sources.v2.TableProvider;
+import org.apache.spark.sql.sources.v2.reader.InputPartition;
+import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
+import org.apache.spark.sql.sources.v2.reader.Statistics;
+import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+
+public class JavaReportStatisticsDataSource implements TableProvider {
+ class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics {
+ @Override
+ public Statistics estimateStatistics() {
+ return new Statistics() {
+ @Override
+ public OptionalLong sizeInBytes() {
+ return OptionalLong.of(80);
+ }
+
+ @Override
+ public OptionalLong numRows() {
+ return OptionalLong.of(10);
+ }
+ };
+ }
+
+ @Override
+ public InputPartition[] planInputPartitions() {
+ InputPartition[] partitions = new InputPartition[2];
+ partitions[0] = new JavaRangeInputPartition(0, 5);
+ partitions[1] = new JavaRangeInputPartition(5, 10);
+ return partitions;
+ }
+ }
+
+ @Override
+ public Table getTable(CaseInsensitiveStringMap options) {
+ return new JavaSimpleBatchTable() {
+ @Override
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
+ return new MyScanBuilder();
+ }
+ };
+ }
+}
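
JavaReportStatisticsDataSource above exercises SupportsReportStatistics, whose estimateStatistics() exposes optional sizeInBytes and numRows estimates to the planner. A tiny Scala sketch of producing and consuming such statistics with java.util.OptionalLong; the Stats trait below is a simplified stand-in, not the Spark Statistics interface.

import java.util.OptionalLong

object ReportStatisticsSketch {
  // Stand-in for the v2 Statistics shape: both estimates are optional.
  trait Stats { def sizeInBytes: OptionalLong; def numRows: OptionalLong }

  val demo: Stats = new Stats {
    val sizeInBytes: OptionalLong = OptionalLong.of(80L)
    val numRows: OptionalLong = OptionalLong.of(10L)
  }

  // A planner-style consumer treats a missing estimate as "unknown" and uses a default.
  def sizeOrDefault(stats: Stats, default: Long): Long =
    if (stats.sizeInBytes.isPresent) stats.sizeInBytes.getAsLong else default

  def main(args: Array[String]): Unit = {
    println(sizeOrDefault(demo, default = Long.MaxValue))  // 80
    val unknown = new Stats {
      val sizeInBytes: OptionalLong = OptionalLong.empty()
      val numRows: OptionalLong = OptionalLong.empty()
    }
    println(sizeOrDefault(unknown, default = Long.MaxValue))  // falls back to the default
  }
}
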
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
index 815d57ba94139..3800a94f88898 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java
@@ -17,11 +17,11 @@
package test.org.apache.spark.sql.sources.v2;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.Table;
import org.apache.spark.sql.sources.v2.TableProvider;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
public class JavaSchemaRequiredDataSource implements TableProvider {
@@ -45,7 +45,7 @@ public InputPartition[] planInputPartitions() {
}
@Override
- public Table getTable(DataSourceOptions options, StructType schema) {
+ public Table getTable(CaseInsensitiveStringMap options, StructType schema) {
return new JavaSimpleBatchTable() {
@Override
@@ -54,14 +54,14 @@ public StructType schema() {
}
@Override
- public ScanBuilder newScanBuilder(DataSourceOptions options) {
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new MyScanBuilder(schema);
}
};
}
@Override
- public Table getTable(DataSourceOptions options) {
+ public Table getTable(CaseInsensitiveStringMap options) {
throw new IllegalArgumentException("requires a user-supplied schema");
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java
index cb5954d5a6211..9b0eb610a206f 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java
@@ -18,15 +18,23 @@
package test.org.apache.spark.sql.sources.v2;
import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
-import org.apache.spark.sql.sources.v2.SupportsBatchRead;
+import org.apache.spark.sql.sources.v2.SupportsRead;
import org.apache.spark.sql.sources.v2.Table;
+import org.apache.spark.sql.sources.v2.TableCapability;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.StructType;
-abstract class JavaSimpleBatchTable implements Table, SupportsBatchRead {
+abstract class JavaSimpleBatchTable implements Table, SupportsRead {
+ private static final Set<TableCapability> CAPABILITIES = new HashSet<>(Arrays.asList(
+ TableCapability.BATCH_READ,
+ TableCapability.BATCH_WRITE,
+ TableCapability.TRUNCATE));
@Override
public StructType schema() {
@@ -37,6 +45,11 @@ public StructType schema() {
public String name() {
return this.getClass().toString();
}
+
+ @Override
+ public Set<TableCapability> capabilities() {
+ return CAPABILITIES;
+ }
}
abstract class JavaSimpleScanBuilder implements ScanBuilder, Scan, Batch {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
index 852c4546df885..7474f36c97f75 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java
@@ -17,10 +17,10 @@
package test.org.apache.spark.sql.sources.v2;
-import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.Table;
import org.apache.spark.sql.sources.v2.TableProvider;
import org.apache.spark.sql.sources.v2.reader.*;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
public class JavaSimpleDataSourceV2 implements TableProvider {
@@ -36,10 +36,10 @@ public InputPartition[] planInputPartitions() {
}
@Override
- public Table getTable(DataSourceOptions options) {
+ public Table getTable(CaseInsensitiveStringMap options) {
return new JavaSimpleBatchTable() {
@Override
- public ScanBuilder newScanBuilder(DataSourceOptions options) {
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new MyScanBuilder();
}
};
diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
index a36b0cfa6ff18..914af589384df 100644
--- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
+++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
@@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly
org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly
org.apache.spark.sql.streaming.sources.FakeReadBothModes
org.apache.spark.sql.streaming.sources.FakeReadNeitherMode
-org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider
+org.apache.spark.sql.streaming.sources.FakeWriteOnly
org.apache.spark.sql.streaming.sources.FakeNoWrite
org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback
diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-query.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-query.sql
new file mode 100644
index 0000000000000..bc144d01cee64
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/describe-query.sql
@@ -0,0 +1,27 @@
+-- Test tables
+CREATE table desc_temp1 (key int COMMENT 'column_comment', val string) USING PARQUET;
+CREATE table desc_temp2 (key int, val string) USING PARQUET;
+
+-- Simple Describe query
+DESC SELECT key, key + 1 as plusone FROM desc_temp1;
+DESC QUERY SELECT * FROM desc_temp2;
+DESC SELECT key, COUNT(*) as count FROM desc_temp1 group by key;
+DESC SELECT 10.00D as col1;
+DESC QUERY SELECT key FROM desc_temp1 UNION ALL select CAST(1 AS DOUBLE);
+DESC QUERY VALUES(1.00D, 'hello') as tab1(col1, col2);
+DESC QUERY FROM desc_temp1 a SELECT *;
+
+
+-- Error cases.
+DESC WITH s AS (SELECT 'hello' as col1) SELECT * FROM s;
+DESCRIBE QUERY WITH s AS (SELECT * from desc_temp1) SELECT * FROM s;
+DESCRIBE INSERT INTO desc_temp1 values (1, 'val1');
+DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2;
+DESCRIBE
+ FROM desc_temp1 a
+ insert into desc_temp1 select *
+ insert into desc_temp2 select *;
+
+-- cleanup
+DROP TABLE desc_temp1;
+DROP TABLE desc_temp2;
diff --git a/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out
new file mode 100644
index 0000000000000..36cb314884779
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out
@@ -0,0 +1,171 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 16
+
+
+-- !query 0
+CREATE table desc_temp1 (key int COMMENT 'column_comment', val string) USING PARQUET
+-- !query 0 schema
+struct<>
+-- !query 0 output
+
+
+
+-- !query 1
+CREATE table desc_temp2 (key int, val string) USING PARQUET
+-- !query 1 schema
+struct<>
+-- !query 1 output
+
+
+
+-- !query 2
+DESC SELECT key, key + 1 as plusone FROM desc_temp1
+-- !query 2 schema
+struct<col_name:string,data_type:string,comment:string>
+-- !query 2 output
+key int column_comment
+plusone int
+
+
+-- !query 3
+DESC QUERY SELECT * FROM desc_temp2
+-- !query 3 schema
+struct<col_name:string,data_type:string,comment:string>
+-- !query 3 output
+key int
+val string
+
+
+-- !query 4
+DESC SELECT key, COUNT(*) as count FROM desc_temp1 group by key
+-- !query 4 schema
+struct<col_name:string,data_type:string,comment:string>
+-- !query 4 output
+key int column_comment
+count bigint
+
+
+-- !query 5
+DESC SELECT 10.00D as col1
+-- !query 5 schema
+struct<col_name:string,data_type:string,comment:string>
+-- !query 5 output
+col1 double
+
+
+-- !query 6
+DESC QUERY SELECT key FROM desc_temp1 UNION ALL select CAST(1 AS DOUBLE)
+-- !query 6 schema
+struct<col_name:string,data_type:string,comment:string>
+-- !query 6 output
+key double
+
+
+-- !query 7
+DESC QUERY VALUES(1.00D, 'hello') as tab1(col1, col2)
+-- !query 7 schema
+struct<col_name:string,data_type:string,comment:string>
+-- !query 7 output
+col1 double
+col2 string
+
+
+-- !query 8
+DESC QUERY FROM desc_temp1 a SELECT *
+-- !query 8 schema
+struct<col_name:string,data_type:string,comment:string>
+-- !query 8 output
+key int column_comment
+val string
+
+
+-- !query 9
+DESC WITH s AS (SELECT 'hello' as col1) SELECT * FROM s
+-- !query 9 schema
+struct<>
+-- !query 9 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+mismatched input 'AS' expecting {<EOF>, '.'}(line 1, pos 12)
+
+== SQL ==
+DESC WITH s AS (SELECT 'hello' as col1) SELECT * FROM s
+------------^^^
+
+
+-- !query 10
+DESCRIBE QUERY WITH s AS (SELECT * from desc_temp1) SELECT * FROM s
+-- !query 10 schema
+struct<>
+-- !query 10 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+mismatched input 's' expecting {<EOF>, '.'}(line 1, pos 20)
+
+== SQL ==
+DESCRIBE QUERY WITH s AS (SELECT * from desc_temp1) SELECT * FROM s
+--------------------^^^
+
+
+-- !query 11
+DESCRIBE INSERT INTO desc_temp1 values (1, 'val1')
+-- !query 11 schema
+struct<>
+-- !query 11 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+mismatched input 'desc_temp1' expecting {<EOF>, '.'}(line 1, pos 21)
+
+== SQL ==
+DESCRIBE INSERT INTO desc_temp1 values (1, 'val1')
+---------------------^^^
+
+
+-- !query 12
+DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2
+-- !query 12 schema
+struct<>
+-- !query 12 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+mismatched input 'desc_temp1' expecting {<EOF>, '.'}(line 1, pos 21)
+
+== SQL ==
+DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2
+---------------------^^^
+
+
+-- !query 13
+DESCRIBE
+ FROM desc_temp1 a
+ insert into desc_temp1 select *
+ insert into desc_temp2 select *
+-- !query 13 schema
+struct<>
+-- !query 13 output
+org.apache.spark.sql.catalyst.parser.ParseException
+
+mismatched input 'insert' expecting {<EOF>, '(', ',', 'SELECT', 'WHERE', 'GROUP', 'ORDER', 'HAVING', 'LIMIT', 'JOIN', 'CROSS', 'INNER', 'LEFT', 'RIGHT', 'FULL', 'NATURAL', 'PIVOT', 'LATERAL', 'WINDOW', 'UNION', 'EXCEPT', 'MINUS', 'INTERSECT', 'SORT', 'CLUSTER', 'DISTRIBUTE', 'ANTI'}(line 3, pos 5)
+
+== SQL ==
+DESCRIBE
+ FROM desc_temp1 a
+ insert into desc_temp1 select *
+-----^^^
+ insert into desc_temp2 select *
+
+
+-- !query 14
+DROP TABLE desc_temp1
+-- !query 14 schema
+struct<>
+-- !query 14 output
+
+
+
+-- !query 15
+DROP TABLE desc_temp2
+-- !query 15 schema
+struct<>
+-- !query 15 output
+
diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0 b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0
new file mode 100644
index 0000000000000..9c1e3021c3ead
--- /dev/null
+++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0
@@ -0,0 +1,2 @@
+v1
+{"nextBatchWatermarkMs":0}
\ No newline at end of file
diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata
new file mode 100644
index 0000000000000..3071b0dfc550b
--- /dev/null
+++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata
@@ -0,0 +1 @@
+{"id":"09be7fb3-49d8-48a6-840d-e9c2ad92a898"}
\ No newline at end of file
diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0 b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0
new file mode 100644
index 0000000000000..a0a567631fd14
--- /dev/null
+++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0
@@ -0,0 +1,3 @@
+v1
+{"batchWatermarkMs":0,"batchTimestampMs":1549649384149,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"200"}}
+0
\ No newline at end of file
diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output %@#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output %@#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet
new file mode 100644
index 0000000000000..1b2919b25c381
Binary files /dev/null and b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output %@#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet differ
diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0 b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0
new file mode 100644
index 0000000000000..79768f89d6eca
--- /dev/null
+++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0
@@ -0,0 +1,2 @@
+v1
+{"path":"file://TEMPDIR/output%20%25@%23output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet","size":404,"isDir":false,"modificationTime":1549649385000,"blockReplication":1,"blockSize":33554432,"action":"add"}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 54342b691109d..e46802f69ed67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -334,83 +334,97 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo
test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") {
withTempDir { dir =>
val tempDir = new File(dir, "files").getCanonicalPath
- // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
- withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
- // write path
- Seq("csv", "json", "parquet", "orc").foreach { format =>
- var msg = intercept[AnalysisException] {
- sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
- }.getMessage
- assert(msg.contains("Cannot save interval data type into external storage."))
-
- msg = intercept[AnalysisException] {
- spark.udf.register("testType", () => new IntervalData())
- sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
- }.getMessage
- assert(msg.toLowerCase(Locale.ROOT)
- .contains(s"$format data source does not support calendarinterval data type."))
+ Seq(true).foreach { useV1 =>
+ val useV1List = if (useV1) {
+ "orc"
+ } else {
+ ""
}
+ def errorMessage(format: String, isWrite: Boolean): String = {
+ if (isWrite && (useV1 || format != "orc")) {
+ "cannot save interval data type into external storage."
+ } else {
+ s"$format data source does not support calendarinterval data type."
+ }
+ }
+
+ withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
+ // write path
+ Seq("csv", "json", "parquet", "orc").foreach { format =>
+ var msg = intercept[AnalysisException] {
+ sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir)
+ }.getMessage
+ assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true)))
+ }
- // read path
- Seq("parquet", "csv").foreach { format =>
- var msg = intercept[AnalysisException] {
- val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
- spark.range(1).write.format(format).mode("overwrite").save(tempDir)
- spark.read.schema(schema).format(format).load(tempDir).collect()
- }.getMessage
- assert(msg.toLowerCase(Locale.ROOT)
- .contains(s"$format data source does not support calendarinterval data type."))
-
- msg = intercept[AnalysisException] {
- val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
- spark.range(1).write.format(format).mode("overwrite").save(tempDir)
- spark.read.schema(schema).format(format).load(tempDir).collect()
- }.getMessage
- assert(msg.toLowerCase(Locale.ROOT)
- .contains(s"$format data source does not support calendarinterval data type."))
+ // read path
+ Seq("parquet", "csv").foreach { format =>
+ var msg = intercept[AnalysisException] {
+ val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil)
+ spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+ spark.read.schema(schema).format(format).load(tempDir).collect()
+ }.getMessage
+ assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
+
+ msg = intercept[AnalysisException] {
+ val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil)
+ spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+ spark.read.schema(schema).format(format).load(tempDir).collect()
+ }.getMessage
+ assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false)))
+ }
}
}
}
}
test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") {
- // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well.
- withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc",
- SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") {
- withTempDir { dir =>
- val tempDir = new File(dir, "files").getCanonicalPath
-
- Seq("parquet", "csv", "orc").foreach { format =>
- // write path
- var msg = intercept[AnalysisException] {
- sql("select null").write.format(format).mode("overwrite").save(tempDir)
- }.getMessage
- assert(msg.toLowerCase(Locale.ROOT)
- .contains(s"$format data source does not support null data type."))
-
- msg = intercept[AnalysisException] {
- spark.udf.register("testType", () => new NullData())
- sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
- }.getMessage
- assert(msg.toLowerCase(Locale.ROOT)
- .contains(s"$format data source does not support null data type."))
-
- // read path
- msg = intercept[AnalysisException] {
- val schema = StructType(StructField("a", NullType, true) :: Nil)
- spark.range(1).write.format(format).mode("overwrite").save(tempDir)
- spark.read.schema(schema).format(format).load(tempDir).collect()
- }.getMessage
- assert(msg.toLowerCase(Locale.ROOT)
- .contains(s"$format data source does not support null data type."))
-
- msg = intercept[AnalysisException] {
- val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
- spark.range(1).write.format(format).mode("overwrite").save(tempDir)
- spark.read.schema(schema).format(format).load(tempDir).collect()
- }.getMessage
- assert(msg.toLowerCase(Locale.ROOT)
- .contains(s"$format data source does not support null data type."))
+ Seq(true).foreach { useV1 =>
+ val useV1List = if (useV1) {
+ "orc"
+ } else {
+ ""
+ }
+ def errorMessage(format: String): String = {
+ s"$format data source does not support null data type."
+ }
+ withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1List,
+ SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) {
+ withTempDir { dir =>
+ val tempDir = new File(dir, "files").getCanonicalPath
+
+ Seq("parquet", "csv", "orc").foreach { format =>
+ // write path
+ var msg = intercept[AnalysisException] {
+ sql("select null").write.format(format).mode("overwrite").save(tempDir)
+ }.getMessage
+ assert(msg.toLowerCase(Locale.ROOT)
+ .contains(errorMessage(format)))
+
+ msg = intercept[AnalysisException] {
+ spark.udf.register("testType", () => new NullData())
+ sql("select testType()").write.format(format).mode("overwrite").save(tempDir)
+ }.getMessage
+ assert(msg.toLowerCase(Locale.ROOT)
+ .contains(errorMessage(format)))
+
+ // read path
+ msg = intercept[AnalysisException] {
+ val schema = StructType(StructField("a", NullType, true) :: Nil)
+ spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+ spark.read.schema(schema).format(format).load(tempDir).collect()
+ }.getMessage
+ assert(msg.toLowerCase(Locale.ROOT)
+ .contains(errorMessage(format)))
+
+ msg = intercept[AnalysisException] {
+ val schema = StructType(StructField("a", new NullUDT(), true) :: Nil)
+ spark.range(1).write.format(format).mode("overwrite").save(tempDir)
+ spark.read.schema(schema).format(format).load(tempDir).collect()
+ }.getMessage
+ assert(msg.toLowerCase(Locale.ROOT)
+ .contains(errorMessage(format)))
+ }
}
}
}
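
The test rewrite above wraps the old body in Seq(true).foreach { useV1 => ... } with a format-aware errorMessage helper, so that once V2 file sources validate data types (SPARK-26744) the same test can iterate Seq(true, false) with no further restructuring. A minimal sketch of that parameterization pattern outside the Spark test harness (checkFormat below is a hypothetical stand-in for the per-format assertion):

object ParameterizedTestSketch {
  // Hypothetical stand-in for the per-format assertion used in the suite.
  def checkFormat(format: String, useV1: Boolean): Unit = {
    val expected =
      if (useV1 || format != "orc") "cannot save interval data type into external storage."
      else s"$format data source does not support calendarinterval data type."
    println(s"useV1=$useV1 format=$format -> expect: $expected")
  }

  def main(args: Array[String]): Unit = {
    // Today only the V1 path runs; adding `false` here later enables the V2 path.
    Seq(true).foreach { useV1 =>
      Seq("csv", "json", "parquet", "orc").foreach(checkFormat(_, useV1))
    }
  }
}
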
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index 24b312348bd67..62f3f98bf28ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile}
import org.apache.spark.sql.execution.HiveResult.hiveResultString
-import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand}
+import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType
@@ -277,7 +277,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
// Returns true if the plan is supposed to be sorted.
def isSorted(plan: LogicalPlan): Boolean = plan match {
case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
- case _: DescribeTableCommand | _: DescribeColumnCommand => true
+ case _: DescribeCommandBase | _: DescribeColumnCommand => true
case PhysicalOperation(_, _, Sort(_, true, _)) => true
case _ => plan.children.iterator.exists(isSorted)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 9f33feb1950c7..881268440ccd7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -234,6 +234,9 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars
override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
delegate.parseFunctionIdentifier(sqlText)
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] =
+ delegate.parseMultipartIdentifier(sqlText)
+
override def parseTableSchema(sqlText: String): StructType =
delegate.parseTableSchema(sqlText)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
index 31b9bcdafbab8..be3d0794d4036 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala
@@ -215,19 +215,6 @@ class SparkSqlParserSuite extends AnalysisTest {
"no viable alternative at input")
}
- test("create table using - schema") {
- assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet",
- createTableUsing(
- table = "my_tab",
- schema = (new StructType)
- .add("a", IntegerType, nullable = true, "test")
- .add("b", StringType)
- )
- )
- intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet",
- "no viable alternative at input")
- }
-
test("create view as insert into table") {
// Single insert query
intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)",
@@ -240,15 +227,20 @@ class SparkSqlParserSuite extends AnalysisTest {
}
test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") {
+ assertEqual("describe t",
+ DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = false))
assertEqual("describe table t",
- DescribeTableCommand(
- TableIdentifier("t"), Map.empty, isExtended = false))
+ DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = false))
assertEqual("describe table extended t",
- DescribeTableCommand(
- TableIdentifier("t"), Map.empty, isExtended = true))
+ DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = true))
assertEqual("describe table formatted t",
- DescribeTableCommand(
- TableIdentifier("t"), Map.empty, isExtended = true))
+ DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = true))
+ }
+
+ test("describe query") {
+ val query = "SELECT * FROM t"
+ assertEqual("DESCRIBE QUERY " + query, DescribeQueryCommand(parser.parsePlan(query)))
+ assertEqual("DESCRIBE " + query, DescribeQueryCommand(parser.parsePlan(query)))
}
test("describe table column") {
@@ -387,4 +379,12 @@ class SparkSqlParserSuite extends AnalysisTest {
"INSERT INTO tbl2 SELECT * WHERE jt.id > 4",
"Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+")
}
+
+ test("database and schema tokens are interchangeable") {
+ assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo"))
+ assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo"))
+ assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')",
+ parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')"))
+ assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo"))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
index c36872a6a5289..86874b9817c20 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType}
+import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.util.Utils
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
index e0ccae15f1d05..0dd11c1e518e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala
@@ -32,13 +32,12 @@ import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.JsonTuple
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, InsertIntoDir, LogicalPlan}
-import org.apache.spark.sql.catalyst.plans.logical.{Project, ScriptTransformation}
+import org.apache.spark.sql.catalyst.plans.logical.{Generate, InsertIntoDir, LogicalPlan, Project, ScriptTransformation}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
class DDLParserSuite extends PlanTest with SharedSQLContext {
@@ -415,173 +414,28 @@ class DDLParserSuite extends PlanTest with SharedSQLContext {
assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything")))
}
- test("create table - with partitioned by") {
- val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " +
- "USING parquet PARTITIONED BY (a)"
-
- val expectedTableDesc = CatalogTable(
- identifier = TableIdentifier("my_tab"),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty,
- schema = new StructType()
- .add("a", IntegerType, nullable = true, "test")
- .add("b", StringType),
- provider = Some("parquet"),
- partitionColumnNames = Seq("a")
- )
-
- parser.parsePlan(query) match {
- case CreateTable(tableDesc, _, None) =>
- assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $query")
+ test("Duplicate clauses - create hive table") {
+ def createTableHeader(duplicateClause: String): String = {
+ s"CREATE TABLE my_tab(a INT, b STRING) STORED AS parquet $duplicateClause $duplicateClause"
}
- }
-
- test("create table - with bucket") {
- val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " +
- "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS"
-
- val expectedTableDesc = CatalogTable(
- identifier = TableIdentifier("my_tab"),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty,
- schema = new StructType().add("a", IntegerType).add("b", StringType),
- provider = Some("parquet"),
- bucketSpec = Some(BucketSpec(5, Seq("a"), Seq("b")))
- )
-
- parser.parsePlan(query) match {
- case CreateTable(tableDesc, _, None) =>
- assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $query")
- }
- }
-
- test("create table - with comment") {
- val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'"
- val expectedTableDesc = CatalogTable(
- identifier = TableIdentifier("my_tab"),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty,
- schema = new StructType().add("a", IntegerType).add("b", StringType),
- provider = Some("parquet"),
- comment = Some("abc"))
-
- parser.parsePlan(sql) match {
- case CreateTable(tableDesc, _, None) =>
- assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
- }
- }
-
- test("create table - with table properties") {
- val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')"
-
- val expectedTableDesc = CatalogTable(
- identifier = TableIdentifier("my_tab"),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty,
- schema = new StructType().add("a", IntegerType).add("b", StringType),
- provider = Some("parquet"),
- properties = Map("test" -> "test"))
-
- parser.parsePlan(sql) match {
- case CreateTable(tableDesc, _, None) =>
- assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
- }
- }
-
- test("Duplicate clauses - create table") {
- def createTableHeader(duplicateClause: String, isNative: Boolean): String = {
- val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet"
- s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause"
- }
-
- Seq(true, false).foreach { isNative =>
- intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative),
- "Found duplicate clauses: TBLPROPERTIES")
- intercept(createTableHeader("LOCATION '/tmp/file'", isNative),
- "Found duplicate clauses: LOCATION")
- intercept(createTableHeader("COMMENT 'a table'", isNative),
- "Found duplicate clauses: COMMENT")
- intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative),
- "Found duplicate clauses: CLUSTERED BY")
- }
-
- // Only for native data source tables
- intercept(createTableHeader("PARTITIONED BY (b)", isNative = true),
- "Found duplicate clauses: PARTITIONED BY")
-
- // Only for Hive serde tables
- intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false),
+ intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"),
+ "Found duplicate clauses: TBLPROPERTIES")
+ intercept(createTableHeader("LOCATION '/tmp/file'"),
+ "Found duplicate clauses: LOCATION")
+ intercept(createTableHeader("COMMENT 'a table'"),
+ "Found duplicate clauses: COMMENT")
+ intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"),
+ "Found duplicate clauses: CLUSTERED BY")
+ intercept(createTableHeader("PARTITIONED BY (k int)"),
"Found duplicate clauses: PARTITIONED BY")
- intercept(createTableHeader("STORED AS parquet", isNative = false),
+ intercept(createTableHeader("STORED AS parquet"),
"Found duplicate clauses: STORED AS/BY")
intercept(
- createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false),
+ createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'"),
"Found duplicate clauses: ROW FORMAT")
}
- test("create table - with location") {
- val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'"
-
- val expectedTableDesc = CatalogTable(
- identifier = TableIdentifier("my_tab"),
- tableType = CatalogTableType.EXTERNAL,
- storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))),
- schema = new StructType().add("a", IntegerType).add("b", StringType),
- provider = Some("parquet"))
-
- parser.parsePlan(v1) match {
- case CreateTable(tableDesc, _, None) =>
- assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $v1")
- }
-
- val v2 =
- """
- |CREATE TABLE my_tab(a INT, b STRING)
- |USING parquet
- |OPTIONS (path '/tmp/file')
- |LOCATION '/tmp/file'
- """.stripMargin
- val e = intercept[ParseException] {
- parser.parsePlan(v2)
- }
- assert(e.message.contains("you can only specify one of them."))
- }
-
- test("create table - byte length literal table name") {
- val sql = "CREATE TABLE 1m.2g(a INT) USING parquet"
-
- val expectedTableDesc = CatalogTable(
- identifier = TableIdentifier("2g", Some("1m")),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty,
- schema = new StructType().add("a", IntegerType),
- provider = Some("parquet"))
-
- parser.parsePlan(sql) match {
- case CreateTable(tableDesc, _, None) =>
- assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
- }
- }
-
test("insert overwrite directory") {
val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a"
parser.parsePlan(v1) match {
@@ -1032,64 +886,6 @@ class DDLParserSuite extends PlanTest with SharedSQLContext {
assert(e.contains("Found an empty partition key 'b'"))
}
- test("drop table") {
- val tableName1 = "db.tab"
- val tableName2 = "tab"
-
- val parsed = Seq(
- s"DROP TABLE $tableName1",
- s"DROP TABLE IF EXISTS $tableName1",
- s"DROP TABLE $tableName2",
- s"DROP TABLE IF EXISTS $tableName2",
- s"DROP TABLE $tableName2 PURGE",
- s"DROP TABLE IF EXISTS $tableName2 PURGE"
- ).map(parser.parsePlan)
-
- val expected = Seq(
- DropTableCommand(TableIdentifier("tab", Option("db")), ifExists = false, isView = false,
- purge = false),
- DropTableCommand(TableIdentifier("tab", Option("db")), ifExists = true, isView = false,
- purge = false),
- DropTableCommand(TableIdentifier("tab", None), ifExists = false, isView = false,
- purge = false),
- DropTableCommand(TableIdentifier("tab", None), ifExists = true, isView = false,
- purge = false),
- DropTableCommand(TableIdentifier("tab", None), ifExists = false, isView = false,
- purge = true),
- DropTableCommand(TableIdentifier("tab", None), ifExists = true, isView = false,
- purge = true))
-
- parsed.zip(expected).foreach { case (p, e) => comparePlans(p, e) }
- }
-
- test("drop view") {
- val viewName1 = "db.view"
- val viewName2 = "view"
-
- val parsed1 = parser.parsePlan(s"DROP VIEW $viewName1")
- val parsed2 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName1")
- val parsed3 = parser.parsePlan(s"DROP VIEW $viewName2")
- val parsed4 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName2")
-
- val expected1 =
- DropTableCommand(TableIdentifier("view", Option("db")), ifExists = false, isView = true,
- purge = false)
- val expected2 =
- DropTableCommand(TableIdentifier("view", Option("db")), ifExists = true, isView = true,
- purge = false)
- val expected3 =
- DropTableCommand(TableIdentifier("view", None), ifExists = false, isView = true,
- purge = false)
- val expected4 =
- DropTableCommand(TableIdentifier("view", None), ifExists = true, isView = true,
- purge = false)
-
- comparePlans(parsed1, expected1)
- comparePlans(parsed2, expected2)
- comparePlans(parsed3, expected3)
- comparePlans(parsed4, expected4)
- }
-
test("show columns") {
val sql1 = "SHOW COLUMNS FROM t1"
val sql2 = "SHOW COLUMNS IN db1.t1"
@@ -1165,84 +961,6 @@ class DDLParserSuite extends PlanTest with SharedSQLContext {
comparePlans(parsed, expected)
}
- test("support for other types in OPTIONS") {
- val sql =
- """
- |CREATE TABLE table_name USING json
- |OPTIONS (a 1, b 0.1, c TRUE)
- """.stripMargin
-
- val expectedTableDesc = CatalogTable(
- identifier = TableIdentifier("table_name"),
- tableType = CatalogTableType.MANAGED,
- storage = CatalogStorageFormat.empty.copy(
- properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true")
- ),
- schema = new StructType,
- provider = Some("json")
- )
-
- parser.parsePlan(sql) match {
- case CreateTable(tableDesc, _, None) =>
- assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
- case other =>
- fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
- s"got ${other.getClass.getName}: $sql")
- }
- }
-
- test("Test CTAS against data source tables") {
- val s1 =
- """
- |CREATE TABLE IF NOT EXISTS mydb.page_view
- |USING parquet
- |COMMENT 'This is the staging page view table'
- |LOCATION '/user/external/page_view'
- |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
- |AS SELECT * FROM src
- """.stripMargin
-
- val s2 =
- """
- |CREATE TABLE IF NOT EXISTS mydb.page_view
- |USING parquet
- |LOCATION '/user/external/page_view'
- |COMMENT 'This is the staging page view table'
- |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
- |AS SELECT * FROM src
- """.stripMargin
-
- val s3 =
- """
- |CREATE TABLE IF NOT EXISTS mydb.page_view
- |USING parquet
- |COMMENT 'This is the staging page view table'
- |LOCATION '/user/external/page_view'
- |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
- |AS SELECT * FROM src
- """.stripMargin
-
- checkParsing(s1)
- checkParsing(s2)
- checkParsing(s3)
-
- def checkParsing(sql: String): Unit = {
- val (desc, exists) = extractTableDesc(sql)
- assert(exists)
- assert(desc.identifier.database == Some("mydb"))
- assert(desc.identifier.table == "page_view")
- assert(desc.storage.locationUri == Some(new URI("/user/external/page_view")))
- assert(desc.schema.isEmpty) // will be populated later when the table is actually created
- assert(desc.comment == Some("This is the staging page view table"))
- assert(desc.viewText.isEmpty)
- assert(desc.viewDefaultDatabase.isEmpty)
- assert(desc.viewQueryColumnNames.isEmpty)
- assert(desc.partitionColumnNames.isEmpty)
- assert(desc.provider == Some("parquet"))
- assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2"))
- }
- }
-
test("Test CTAS #1") {
val s1 =
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
new file mode 100644
index 0000000000000..06f7332086372
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -0,0 +1,504 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.command
+
+import java.net.URI
+import java.util.Locale
+
+import org.apache.spark.sql.{AnalysisException, SaveMode}
+import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, TableCatalog, TestTableCatalog}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan}
+import org.apache.spark.sql.execution.datasources.{CreateTable, DataSourceResolution}
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class PlanResolutionSuite extends AnalysisTest {
+ import CatalystSqlParser._
+
+ private val orc2 = classOf[OrcDataSourceV2].getName
+
+ private val testCat: TableCatalog = {
+ val newCatalog = new TestTableCatalog
+ newCatalog.initialize("testcat", CaseInsensitiveStringMap.empty())
+ newCatalog
+ }
+
+ private val lookupCatalog: String => CatalogPlugin = {
+ case "testcat" =>
+ testCat
+ case name =>
+ throw new CatalogNotFoundException(s"No such catalog: $name")
+ }
+
+ def parseAndResolve(query: String): LogicalPlan = {
+ val newConf = conf.copy()
+ newConf.setConfString("spark.sql.default.catalog", "testcat")
+ DataSourceResolution(newConf, lookupCatalog).apply(parsePlan(query))
+ }
+
+ private def parseResolveCompare(query: String, expected: LogicalPlan): Unit =
+ comparePlans(parseAndResolve(query), expected, checkAnalysis = true)
+
+ private def extractTableDesc(sql: String): (CatalogTable, Boolean) = {
+ parseAndResolve(sql).collect {
+ case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore)
+ }.head
+ }
+
+ test("create table - with partitioned by") {
+ val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " +
+ "USING parquet PARTITIONED BY (a)"
+
+ val expectedTableDesc = CatalogTable(
+ identifier = TableIdentifier("my_tab"),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType()
+ .add("a", IntegerType, nullable = true, "test")
+ .add("b", StringType),
+ provider = Some("parquet"),
+ partitionColumnNames = Seq("a")
+ )
+
+ parseAndResolve(query) match {
+ case CreateTable(tableDesc, _, None) =>
+ assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $query")
+ }
+ }
+
+ test("create table - partitioned by transforms") {
+ val transforms = Seq(
+ "bucket(16, b)", "years(ts)", "months(ts)", "days(ts)", "hours(ts)", "foo(a, 'bar', 34)",
+ "bucket(32, b), days(ts)")
+ transforms.foreach { transform =>
+ val query =
+ s"""
+ |CREATE TABLE my_tab(a INT, b STRING) USING parquet
+ |PARTITIONED BY ($transform)
+ """.stripMargin
+
+ val ae = intercept[AnalysisException] {
+ parseAndResolve(query)
+ }
+
+ assert(ae.message
+ .contains(s"Transforms cannot be converted to partition columns: $transform"))
+ }
+ }
+
+ test("create table - with bucket") {
+ val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " +
+ "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS"
+
+ val expectedTableDesc = CatalogTable(
+ identifier = TableIdentifier("my_tab"),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType().add("a", IntegerType).add("b", StringType),
+ provider = Some("parquet"),
+ bucketSpec = Some(BucketSpec(5, Seq("a"), Seq("b")))
+ )
+
+ parseAndResolve(query) match {
+ case CreateTable(tableDesc, _, None) =>
+ assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $query")
+ }
+ }
+
+ test("create table - with comment") {
+ val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'"
+
+ val expectedTableDesc = CatalogTable(
+ identifier = TableIdentifier("my_tab"),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType().add("a", IntegerType).add("b", StringType),
+ provider = Some("parquet"),
+ comment = Some("abc"))
+
+ parseAndResolve(sql) match {
+ case CreateTable(tableDesc, _, None) =>
+ assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("create table - with table properties") {
+ val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')"
+
+ val expectedTableDesc = CatalogTable(
+ identifier = TableIdentifier("my_tab"),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType().add("a", IntegerType).add("b", StringType),
+ provider = Some("parquet"),
+ properties = Map("test" -> "test"))
+
+ parseAndResolve(sql) match {
+ case CreateTable(tableDesc, _, None) =>
+ assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("create table - with location") {
+ val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'"
+
+ val expectedTableDesc = CatalogTable(
+ identifier = TableIdentifier("my_tab"),
+ tableType = CatalogTableType.EXTERNAL,
+ storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))),
+ schema = new StructType().add("a", IntegerType).add("b", StringType),
+ provider = Some("parquet"))
+
+ parseAndResolve(v1) match {
+ case CreateTable(tableDesc, _, None) =>
+ assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $v1")
+ }
+
+ val v2 =
+ """
+ |CREATE TABLE my_tab(a INT, b STRING)
+ |USING parquet
+ |OPTIONS (path '/tmp/file')
+ |LOCATION '/tmp/file'
+ """.stripMargin
+ val e = intercept[AnalysisException] {
+ parseAndResolve(v2)
+ }
+ assert(e.message.contains("you can only specify one of them."))
+ }
+
+ test("create table - byte length literal table name") {
+ val sql = "CREATE TABLE 1m.2g(a INT) USING parquet"
+
+ val expectedTableDesc = CatalogTable(
+ identifier = TableIdentifier("2g", Some("1m")),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType().add("a", IntegerType),
+ provider = Some("parquet"))
+
+ parseAndResolve(sql) match {
+ case CreateTable(tableDesc, _, None) =>
+ assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("support for other types in OPTIONS") {
+ val sql =
+ """
+ |CREATE TABLE table_name USING json
+ |OPTIONS (a 1, b 0.1, c TRUE)
+ """.stripMargin
+
+ val expectedTableDesc = CatalogTable(
+ identifier = TableIdentifier("table_name"),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty.copy(
+ properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true")
+ ),
+ schema = new StructType,
+ provider = Some("json")
+ )
+
+ parseAndResolve(sql) match {
+ case CreateTable(tableDesc, _, None) =>
+ assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime))
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("Test CTAS against data source tables") {
+ val s1 =
+ """
+ |CREATE TABLE IF NOT EXISTS mydb.page_view
+ |USING parquet
+ |COMMENT 'This is the staging page view table'
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src
+ """.stripMargin
+
+ val s2 =
+ """
+ |CREATE TABLE IF NOT EXISTS mydb.page_view
+ |USING parquet
+ |LOCATION '/user/external/page_view'
+ |COMMENT 'This is the staging page view table'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src
+ """.stripMargin
+
+ val s3 =
+ """
+ |CREATE TABLE IF NOT EXISTS mydb.page_view
+ |USING parquet
+ |COMMENT 'This is the staging page view table'
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src
+ """.stripMargin
+
+ checkParsing(s1)
+ checkParsing(s2)
+ checkParsing(s3)
+
+ def checkParsing(sql: String): Unit = {
+ val (desc, exists) = extractTableDesc(sql)
+ assert(exists)
+ assert(desc.identifier.database.contains("mydb"))
+ assert(desc.identifier.table == "page_view")
+ assert(desc.storage.locationUri.contains(new URI("/user/external/page_view")))
+ assert(desc.schema.isEmpty) // will be populated later when the table is actually created
+ assert(desc.comment.contains("This is the staging page view table"))
+ assert(desc.viewText.isEmpty)
+ assert(desc.viewDefaultDatabase.isEmpty)
+ assert(desc.viewQueryColumnNames.isEmpty)
+ assert(desc.partitionColumnNames.isEmpty)
+ assert(desc.provider.contains("parquet"))
+ assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2"))
+ }
+ }
+
+ test("Test v2 CreateTable with known catalog in identifier") {
+ val sql =
+ s"""
+ |CREATE TABLE IF NOT EXISTS testcat.mydb.table_name (
+ | id bigint,
+ | description string,
+      |  point struct<x: double, y: double>)
+ |USING parquet
+ |COMMENT 'table comment'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |OPTIONS (path 's3://bucket/path/to/data', other 20)
+ """.stripMargin
+
+ val expectedProperties = Map(
+ "p1" -> "v1",
+ "p2" -> "v2",
+ "other" -> "20",
+ "provider" -> "parquet",
+ "location" -> "s3://bucket/path/to/data",
+ "comment" -> "table comment")
+
+ parseAndResolve(sql) match {
+ case create: CreateV2Table =>
+ assert(create.catalog.name == "testcat")
+ assert(create.tableName == Identifier.of(Array("mydb"), "table_name"))
+ assert(create.tableSchema == new StructType()
+ .add("id", LongType)
+ .add("description", StringType)
+ .add("point", new StructType().add("x", DoubleType).add("y", DoubleType)))
+ assert(create.partitioning.isEmpty)
+ assert(create.properties == expectedProperties)
+ assert(create.ignoreIfExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("Test v2 CreateTable with data source v2 provider") {
+ val sql =
+ s"""
+ |CREATE TABLE IF NOT EXISTS mydb.page_view (
+ | id bigint,
+ | description string,
+      |  point struct<x: double, y: double>)
+ |USING $orc2
+ |COMMENT 'This is the staging page view table'
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ """.stripMargin
+
+ val expectedProperties = Map(
+ "p1" -> "v1",
+ "p2" -> "v2",
+ "provider" -> orc2,
+ "location" -> "/user/external/page_view",
+ "comment" -> "This is the staging page view table")
+
+ parseAndResolve(sql) match {
+ case create: CreateV2Table =>
+ assert(create.catalog.name == "testcat")
+ assert(create.tableName == Identifier.of(Array("mydb"), "page_view"))
+ assert(create.tableSchema == new StructType()
+ .add("id", LongType)
+ .add("description", StringType)
+ .add("point", new StructType().add("x", DoubleType).add("y", DoubleType)))
+ assert(create.partitioning.isEmpty)
+ assert(create.properties == expectedProperties)
+ assert(create.ignoreIfExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("Test v2 CTAS with known catalog in identifier") {
+ val sql =
+ s"""
+ |CREATE TABLE IF NOT EXISTS testcat.mydb.table_name
+ |USING parquet
+ |COMMENT 'table comment'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |OPTIONS (path 's3://bucket/path/to/data', other 20)
+ |AS SELECT * FROM src
+ """.stripMargin
+
+ val expectedProperties = Map(
+ "p1" -> "v1",
+ "p2" -> "v2",
+ "other" -> "20",
+ "provider" -> "parquet",
+ "location" -> "s3://bucket/path/to/data",
+ "comment" -> "table comment")
+
+ parseAndResolve(sql) match {
+ case ctas: CreateTableAsSelect =>
+ assert(ctas.catalog.name == "testcat")
+ assert(ctas.tableName == Identifier.of(Array("mydb"), "table_name"))
+ assert(ctas.properties == expectedProperties)
+ assert(ctas.writeOptions == Map("other" -> "20"))
+ assert(ctas.partitioning.isEmpty)
+ assert(ctas.ignoreIfExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("Test v2 CTAS with data source v2 provider") {
+ val sql =
+ s"""
+ |CREATE TABLE IF NOT EXISTS mydb.page_view
+ |USING $orc2
+ |COMMENT 'This is the staging page view table'
+ |LOCATION '/user/external/page_view'
+ |TBLPROPERTIES ('p1'='v1', 'p2'='v2')
+ |AS SELECT * FROM src
+ """.stripMargin
+
+ val expectedProperties = Map(
+ "p1" -> "v1",
+ "p2" -> "v2",
+ "provider" -> orc2,
+ "location" -> "/user/external/page_view",
+ "comment" -> "This is the staging page view table")
+
+ parseAndResolve(sql) match {
+ case ctas: CreateTableAsSelect =>
+ assert(ctas.catalog.name == "testcat")
+ assert(ctas.tableName == Identifier.of(Array("mydb"), "page_view"))
+ assert(ctas.properties == expectedProperties)
+ assert(ctas.writeOptions.isEmpty)
+ assert(ctas.partitioning.isEmpty)
+ assert(ctas.ignoreIfExists)
+
+ case other =>
+ fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query," +
+ s"got ${other.getClass.getName}: $sql")
+ }
+ }
+
+ test("drop table") {
+ val tableName1 = "db.tab"
+ val tableIdent1 = TableIdentifier("tab", Option("db"))
+ val tableName2 = "tab"
+ val tableIdent2 = TableIdentifier("tab", None)
+
+ parseResolveCompare(s"DROP TABLE $tableName1",
+ DropTableCommand(tableIdent1, ifExists = false, isView = false, purge = false))
+ parseResolveCompare(s"DROP TABLE IF EXISTS $tableName1",
+ DropTableCommand(tableIdent1, ifExists = true, isView = false, purge = false))
+ parseResolveCompare(s"DROP TABLE $tableName2",
+ DropTableCommand(tableIdent2, ifExists = false, isView = false, purge = false))
+ parseResolveCompare(s"DROP TABLE IF EXISTS $tableName2",
+ DropTableCommand(tableIdent2, ifExists = true, isView = false, purge = false))
+ parseResolveCompare(s"DROP TABLE $tableName2 PURGE",
+ DropTableCommand(tableIdent2, ifExists = false, isView = false, purge = true))
+ parseResolveCompare(s"DROP TABLE IF EXISTS $tableName2 PURGE",
+ DropTableCommand(tableIdent2, ifExists = true, isView = false, purge = true))
+ }
+
+ test("drop table in v2 catalog") {
+ val tableName1 = "testcat.db.tab"
+ val tableIdent1 = Identifier.of(Array("db"), "tab")
+ val tableName2 = "testcat.tab"
+ val tableIdent2 = Identifier.of(Array.empty, "tab")
+
+ parseResolveCompare(s"DROP TABLE $tableName1",
+ DropTable(testCat, tableIdent1, ifExists = false))
+ parseResolveCompare(s"DROP TABLE IF EXISTS $tableName1",
+ DropTable(testCat, tableIdent1, ifExists = true))
+ parseResolveCompare(s"DROP TABLE $tableName2",
+ DropTable(testCat, tableIdent2, ifExists = false))
+ parseResolveCompare(s"DROP TABLE IF EXISTS $tableName2",
+ DropTable(testCat, tableIdent2, ifExists = true))
+ }
+
+ test("drop view") {
+ val viewName1 = "db.view"
+ val viewIdent1 = TableIdentifier("view", Option("db"))
+ val viewName2 = "view"
+ val viewIdent2 = TableIdentifier("view")
+
+ parseResolveCompare(s"DROP VIEW $viewName1",
+ DropTableCommand(viewIdent1, ifExists = false, isView = true, purge = false))
+ parseResolveCompare(s"DROP VIEW IF EXISTS $viewName1",
+ DropTableCommand(viewIdent1, ifExists = true, isView = true, purge = false))
+ parseResolveCompare(s"DROP VIEW $viewName2",
+ DropTableCommand(viewIdent2, ifExists = false, isView = true, purge = false))
+ parseResolveCompare(s"DROP VIEW IF EXISTS $viewName2",
+ DropTableCommand(viewIdent2, ifExists = true, isView = true, purge = false))
+ }
+
+ test("drop view in v2 catalog") {
+    assert(intercept[AnalysisException] {
+      parseAndResolve("DROP VIEW testcat.db.view")
+    }.getMessage.toLowerCase(Locale.ROOT).contains(
+      "view support in catalog has not been implemented"))
+ }
+}
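Reviewer note: the new PlanResolutionSuite above covers both resolution outcomes, v1 commands bound to the session catalog and v2 plans bound to `testcat`. A minimal sketch of that split, not part of the patch and using only plan classes the suite already imports (plus DropTableCommand from its own package):

    // Illustrative helper only: classify which path DataSourceResolution chose.
    def resolutionPath(plan: LogicalPlan): String = plan match {
      case _: CreateV2Table | _: CreateTableAsSelect | _: DropTable => "v2 catalog plan"
      case _: CreateTable | _: DropTableCommand => "v1 session-catalog command"
      case other => s"unresolved/other: ${other.getClass.getSimpleName}"
    }

For example, resolutionPath(parseAndResolve("DROP TABLE testcat.db.tab")) would report the v2 path, while the plain "DROP TABLE db.tab" case falls back to the v1 command.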
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
index f20aded169e44..2f5d5551c5df0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
@@ -219,6 +219,13 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext {
IsNotNull(attrInt))), None)
}
+ test("SPARK-26865 DataSourceV2Strategy should push normalized filters") {
+ val attrInt = 'cint.int
+ assertResult(Seq(IsNotNull(attrInt))) {
+ DataSourceStrategy.normalizeFilters(Seq(IsNotNull(attrInt.withName("CiNt"))), Seq(attrInt))
+ }
+ }
+
/**
* Translate the given Catalyst [[Expression]] into data source [[sources.Filter]]
* then verify against the given [[sources.Filter]].
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index cccd8e9ee8bd1..034454d21d7ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsR
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -58,7 +57,7 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext {
case PhysicalOperation(_, filters,
DataSourceV2Relation(orcTable: OrcTable, _, options)) =>
assert(filters.nonEmpty, "No filter is analyzed from the given query")
- val scanBuilder = orcTable.newScanBuilder(new DataSourceOptions(options.asJava))
+ val scanBuilder = orcTable.newScanBuilder(options)
scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray)
val pushedFilters = scanBuilder.pushedFilters()
assert(pushedFilters.nonEmpty, "No filter is pushed down")
@@ -102,7 +101,7 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext {
case PhysicalOperation(_, filters,
DataSourceV2Relation(orcTable: OrcTable, _, options)) =>
assert(filters.nonEmpty, "No filter is analyzed from the given query")
- val scanBuilder = orcTable.newScanBuilder(new DataSourceOptions(options.asJava))
+ val scanBuilder = orcTable.newScanBuilder(options)
scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray)
val pushedFilters = scanBuilder.pushedFilters()
if (noneSupported) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala
index 4a695ac74c476..b4d92c3b2d2fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.orc
import java.io.File
+import org.apache.hadoop.fs.{Path, PathFilter}
+
import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.internal.SQLConf
@@ -30,6 +32,10 @@ case class OrcParData(intField: Int, stringField: String)
// The data that also includes the partitioning key
case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
+class TestFileFilter extends PathFilter {
+ override def accept(path: Path): Boolean = path.getParent.getName != "p=2"
+}
+
abstract class OrcPartitionDiscoveryTest extends OrcTest {
val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"
@@ -226,6 +232,23 @@ abstract class OrcPartitionDiscoveryTest extends OrcTest {
}
}
}
+
+ test("SPARK-27162: handle pathfilter configuration correctly") {
+ withTempPath { dir =>
+ val path = dir.getCanonicalPath
+
+ val df = spark.range(2)
+ df.write.orc(path + "/p=1")
+ df.write.orc(path + "/p=2")
+ assert(spark.read.orc(path).count() === 4)
+
+ val extraOptions = Map(
+ "mapred.input.pathFilter.class" -> classOf[TestFileFilter].getName,
+ "mapreduce.input.pathFilter.class" -> classOf[TestFileFilter].getName
+ )
+ assert(spark.read.options(extraOptions).orc(path).count() === 2)
+ }
+ }
}
class OrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala
new file mode 100644
index 0000000000000..8a0450fce76a1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
+import org.apache.spark.sql.catalyst.plans.logical.Union
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.streaming.{Offset, Source, StreamingRelation, StreamingRelationV2}
+import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.sources.v2.{Table, TableCapability, TableProvider}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class V2StreamingScanSupportCheckSuite extends SparkFunSuite with SharedSparkSession {
+ import TableCapability._
+
+ private def createStreamingRelation(table: Table, v1Relation: Option[StreamingRelation]) = {
+ StreamingRelationV2(FakeTableProvider, "fake", table, CaseInsensitiveStringMap.empty(),
+ FakeTableProvider.schema.toAttributes, v1Relation)(spark)
+ }
+
+ private def createStreamingRelationV1() = {
+ StreamingRelation(DataSource(spark, classOf[FakeStreamSourceProvider].getName))
+ }
+
+ test("check correct plan") {
+ val plan1 = createStreamingRelation(CapabilityTable(MICRO_BATCH_READ), None)
+ val plan2 = createStreamingRelation(CapabilityTable(CONTINUOUS_READ), None)
+ val plan3 = createStreamingRelation(CapabilityTable(MICRO_BATCH_READ, CONTINUOUS_READ), None)
+ val plan4 = createStreamingRelationV1()
+
+ V2StreamingScanSupportCheck(Union(plan1, plan1))
+ V2StreamingScanSupportCheck(Union(plan2, plan2))
+ V2StreamingScanSupportCheck(Union(plan1, plan3))
+ V2StreamingScanSupportCheck(Union(plan2, plan3))
+ V2StreamingScanSupportCheck(Union(plan1, plan4))
+ V2StreamingScanSupportCheck(Union(plan3, plan4))
+ }
+
+ test("table without scan capability") {
+ val e = intercept[AnalysisException] {
+ V2StreamingScanSupportCheck(createStreamingRelation(CapabilityTable(), None))
+ }
+ assert(e.message.contains("does not support either micro-batch or continuous scan"))
+ }
+
+ test("mix micro-batch only and continuous only") {
+ val plan1 = createStreamingRelation(CapabilityTable(MICRO_BATCH_READ), None)
+ val plan2 = createStreamingRelation(CapabilityTable(CONTINUOUS_READ), None)
+
+ val e = intercept[AnalysisException] {
+ V2StreamingScanSupportCheck(Union(plan1, plan2))
+ }
+ assert(e.message.contains(
+ "The streaming sources in a query do not have a common supported execution mode"))
+ }
+
+ test("mix continuous only and v1 relation") {
+ val plan1 = createStreamingRelation(CapabilityTable(CONTINUOUS_READ), None)
+ val plan2 = createStreamingRelationV1()
+ val e = intercept[AnalysisException] {
+ V2StreamingScanSupportCheck(Union(plan1, plan2))
+ }
+ assert(e.message.contains(
+ "The streaming sources in a query do not have a common supported execution mode"))
+ }
+}
+
+private object FakeTableProvider extends TableProvider {
+ val schema = new StructType().add("i", "int")
+
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ throw new UnsupportedOperationException
+ }
+}
+
+private case class CapabilityTable(_capabilities: TableCapability*) extends Table {
+ override def name(): String = "capability_test_table"
+ override def schema(): StructType = FakeTableProvider.schema
+ override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava
+}
+
+private class FakeStreamSourceProvider extends StreamSourceProvider {
+ override def sourceSchema(
+ sqlContext: SQLContext,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): (String, StructType) = {
+ "fake" -> FakeTableProvider.schema
+ }
+
+ override def createSource(
+ sqlContext: SQLContext,
+ metadataPath: String,
+ schema: Option[StructType],
+ providerName: String,
+ parameters: Map[String, String]): Source = {
+ new Source {
+ override def schema: StructType = FakeTableProvider.schema
+ override def getOffset: Option[Offset] = {
+ throw new UnsupportedOperationException
+ }
+ override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+ throw new UnsupportedOperationException
+ }
+ override def stop(): Unit = {}
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
index 3bc36ce55d902..3ead91fcf712a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
@@ -22,6 +22,8 @@ import scala.language.implicitConversions
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.streaming.sources._
import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
@@ -36,7 +38,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
test("directly add data in Append output mode") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))
- val sink = new MemorySink(schema, OutputMode.Append)
+ val sink = new MemorySink
+ val addBatch = addBatchFunc(sink, false) _
// Before adding data, check output
assert(sink.latestBatchId === None)
@@ -44,25 +47,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
checkAnswer(sink.allData, Seq.empty)
// Add batch 0 and check outputs
- sink.addBatch(0, 1 to 3)
+ addBatch(0, 1 to 3)
assert(sink.latestBatchId === Some(0))
checkAnswer(sink.latestBatchData, 1 to 3)
checkAnswer(sink.allData, 1 to 3)
// Add batch 1 and check outputs
- sink.addBatch(1, 4 to 6)
+ addBatch(1, 4 to 6)
assert(sink.latestBatchId === Some(1))
checkAnswer(sink.latestBatchData, 4 to 6)
checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data
// Re-add batch 1 with different data, should not be added and outputs should not be changed
- sink.addBatch(1, 7 to 9)
+ addBatch(1, 7 to 9)
assert(sink.latestBatchId === Some(1))
checkAnswer(sink.latestBatchData, 4 to 6)
checkAnswer(sink.allData, 1 to 6)
// Add batch 2 and check outputs
- sink.addBatch(2, 7 to 9)
+ addBatch(2, 7 to 9)
assert(sink.latestBatchId === Some(2))
checkAnswer(sink.latestBatchData, 7 to 9)
checkAnswer(sink.allData, 1 to 9)
@@ -70,7 +73,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
test("directly add data in Update output mode") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))
- val sink = new MemorySink(schema, OutputMode.Update)
+ val sink = new MemorySink
+ val addBatch = addBatchFunc(sink, false) _
// Before adding data, check output
assert(sink.latestBatchId === None)
@@ -78,25 +82,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
checkAnswer(sink.allData, Seq.empty)
// Add batch 0 and check outputs
- sink.addBatch(0, 1 to 3)
+ addBatch(0, 1 to 3)
assert(sink.latestBatchId === Some(0))
checkAnswer(sink.latestBatchData, 1 to 3)
checkAnswer(sink.allData, 1 to 3)
// Add batch 1 and check outputs
- sink.addBatch(1, 4 to 6)
+ addBatch(1, 4 to 6)
assert(sink.latestBatchId === Some(1))
checkAnswer(sink.latestBatchData, 4 to 6)
checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data
// Re-add batch 1 with different data, should not be added and outputs should not be changed
- sink.addBatch(1, 7 to 9)
+ addBatch(1, 7 to 9)
assert(sink.latestBatchId === Some(1))
checkAnswer(sink.latestBatchData, 4 to 6)
checkAnswer(sink.allData, 1 to 6)
// Add batch 2 and check outputs
- sink.addBatch(2, 7 to 9)
+ addBatch(2, 7 to 9)
assert(sink.latestBatchId === Some(2))
checkAnswer(sink.latestBatchData, 7 to 9)
checkAnswer(sink.allData, 1 to 9)
@@ -104,7 +108,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
test("directly add data in Complete output mode") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))
- val sink = new MemorySink(schema, OutputMode.Complete)
+ val sink = new MemorySink
+ val addBatch = addBatchFunc(sink, true) _
// Before adding data, check output
assert(sink.latestBatchId === None)
@@ -112,25 +117,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
checkAnswer(sink.allData, Seq.empty)
// Add batch 0 and check outputs
- sink.addBatch(0, 1 to 3)
+ addBatch(0, 1 to 3)
assert(sink.latestBatchId === Some(0))
checkAnswer(sink.latestBatchData, 1 to 3)
checkAnswer(sink.allData, 1 to 3)
// Add batch 1 and check outputs
- sink.addBatch(1, 4 to 6)
+ addBatch(1, 4 to 6)
assert(sink.latestBatchId === Some(1))
checkAnswer(sink.latestBatchData, 4 to 6)
checkAnswer(sink.allData, 4 to 6) // new data should replace old data
// Re-add batch 1 with different data, should not be added and outputs should not be changed
- sink.addBatch(1, 7 to 9)
+ addBatch(1, 7 to 9)
assert(sink.latestBatchId === Some(1))
checkAnswer(sink.latestBatchData, 4 to 6)
checkAnswer(sink.allData, 4 to 6)
// Add batch 2 and check outputs
- sink.addBatch(2, 7 to 9)
+ addBatch(2, 7 to 9)
assert(sink.latestBatchId === Some(2))
checkAnswer(sink.latestBatchData, 7 to 9)
checkAnswer(sink.allData, 7 to 9)
@@ -211,18 +216,19 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
test("MemoryPlan statistics") {
implicit val schema = new StructType().add(new StructField("value", IntegerType))
- val sink = new MemorySink(schema, OutputMode.Append)
- val plan = new MemoryPlan(sink)
+ val sink = new MemorySink
+ val addBatch = addBatchFunc(sink, false) _
+ val plan = new MemoryPlan(sink, schema.toAttributes)
// Before adding data, check output
checkAnswer(sink.allData, Seq.empty)
assert(plan.stats.sizeInBytes === 0)
- sink.addBatch(0, 1 to 3)
+ addBatch(0, 1 to 3)
plan.invalidateStatsCache()
assert(plan.stats.sizeInBytes === 36)
- sink.addBatch(1, 4 to 6)
+ addBatch(1, 4 to 6)
plan.invalidateStatsCache()
assert(plan.stats.sizeInBytes === 72)
}
@@ -285,6 +291,50 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
}
}
+ test("data writer") {
+ val partition = 1234
+ val writer = new MemoryDataWriter(
+ partition, new StructType().add("i", "int"))
+ writer.write(InternalRow(1))
+ writer.write(InternalRow(2))
+ writer.write(InternalRow(44))
+ val msg = writer.commit()
+ assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44))
+ assert(msg.partition == partition)
+
+ // Buffer should be cleared, so repeated commits should give empty.
+ assert(writer.commit().data.isEmpty)
+ }
+
+ test("streaming writer") {
+ val sink = new MemorySink
+ val write = new MemoryStreamingWrite(
+ sink, new StructType().add("i", "int"), needTruncate = false)
+ write.commit(0,
+ Array(
+ MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
+ MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
+ MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
+ ))
+ assert(sink.latestBatchId.contains(0))
+ assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
+ write.commit(19,
+ Array(
+ MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
+ MemoryWriterCommitMessage(0, Seq(Row(33)))
+ ))
+ assert(sink.latestBatchId.contains(19))
+ assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))
+
+ assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
+ }
+
+ private def addBatchFunc(sink: MemorySink, needTruncate: Boolean)(
+ batchId: Long,
+ vals: Seq[Int]): Unit = {
+ sink.write(batchId, needTruncate, vals.map(Row(_)).toArray)
+ }
+
private def checkAnswer(rows: Seq[Row], expected: Seq[Int])(implicit schema: StructType): Unit = {
checkAnswer(
sqlContext.createDataFrame(sparkContext.makeRDD(rows), schema),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
deleted file mode 100644
index 61857365ac989..0000000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming
-
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.streaming.sources._
-import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
-import org.apache.spark.sql.types.StructType
-
-class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
- test("data writer") {
- val partition = 1234
- val writer = new MemoryDataWriter(
- partition, OutputMode.Append(), new StructType().add("i", "int"))
- writer.write(InternalRow(1))
- writer.write(InternalRow(2))
- writer.write(InternalRow(44))
- val msg = writer.commit()
- assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44))
- assert(msg.partition == partition)
-
- // Buffer should be cleared, so repeated commits should give empty.
- assert(writer.commit().data.isEmpty)
- }
-
- test("streaming writer") {
- val sink = new MemorySinkV2
- val writeSupport = new MemoryStreamingWriteSupport(
- sink, OutputMode.Append(), new StructType().add("i", "int"))
- writeSupport.commit(0,
- Array(
- MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
- MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
- MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
- ))
- assert(sink.latestBatchId.contains(0))
- assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
- writeSupport.commit(19,
- Array(
- MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
- MemoryWriterCommitMessage(0, Seq(Row(33)))
- ))
- assert(sink.latestBatchId.contains(19))
- assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))
-
- assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
- }
-}
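The deleted MemorySinkV2Suite cases are folded into MemorySinkSuite above, against the unified MemorySink. A small sketch of driving the sink directly inside that suite, assuming only the write(batchId, needTruncate, rows) entry point the suite itself uses:

    // Sketch only: MemorySink no longer takes a schema or OutputMode; the
    // per-batch truncate flag now carries the Complete-mode semantics.
    val sink = new MemorySink
    sink.write(0, false, Array(Row(1), Row(2)))  // no truncation: rows append
    sink.write(1, true, Array(Row(3)))           // truncation: earlier batches dropped
    assert(sink.latestBatchId.contains(1))
    assert(sink.allData == Seq(Row(3)))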
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
index d0418f893143e..ef88598fcb11b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala
@@ -29,9 +29,9 @@ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relati
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.streaming.Offset
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream}
import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ManualClock
class RateStreamProviderSuite extends StreamTest {
@@ -39,7 +39,7 @@ class RateStreamProviderSuite extends StreamTest {
import testImplicits._
case class AdvanceRateManualClock(seconds: Long) extends AddData {
- override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
+ override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = {
assert(query.nonEmpty)
val rateSource = query.get.logicalPlan.collect {
case r: StreamingDataSourceV2Relation
@@ -135,7 +135,7 @@ class RateStreamProviderSuite extends StreamTest {
withTempDir { temp =>
val stream = new RateStreamMicroBatchStream(
rowsPerSecond = 100,
- options = new DataSourceOptions(Map("useManualClock" -> "true").asJava),
+ options = new CaseInsensitiveStringMap(Map("useManualClock" -> "true").asJava),
checkpointLocation = temp.getCanonicalPath)
stream.clock.asInstanceOf[ManualClock].advance(100000)
val startOffset = stream.initialOffset()
@@ -154,7 +154,7 @@ class RateStreamProviderSuite extends StreamTest {
withTempDir { temp =>
val stream = new RateStreamMicroBatchStream(
rowsPerSecond = 20,
- options = DataSourceOptions.empty(),
+ options = CaseInsensitiveStringMap.empty(),
checkpointLocation = temp.getCanonicalPath)
val partitions = stream.planInputPartitions(LongOffset(0L), LongOffset(1L))
val readerFactory = stream.createReaderFactory()
@@ -173,7 +173,7 @@ class RateStreamProviderSuite extends StreamTest {
val stream = new RateStreamMicroBatchStream(
rowsPerSecond = 33,
numPartitions = 11,
- options = DataSourceOptions.empty(),
+ options = CaseInsensitiveStringMap.empty(),
checkpointLocation = temp.getCanonicalPath)
val partitions = stream.planInputPartitions(LongOffset(0L), LongOffset(1L))
val readerFactory = stream.createReaderFactory()
@@ -305,12 +305,11 @@ class RateStreamProviderSuite extends StreamTest {
.load()
}
assert(exception.getMessage.contains(
- "rate source does not support user-specified schema"))
+ "RateStreamProvider source does not support user-specified schema"))
}
test("continuous data") {
- val stream = new RateStreamContinuousStream(
- rowsPerSecond = 20, numPartitions = 2, options = DataSourceOptions.empty())
+ val stream = new RateStreamContinuousStream(rowsPerSecond = 20, numPartitions = 2)
val partitions = stream.planInputPartitions(stream.initialOffset)
val readerFactory = stream.createContinuousReaderFactory()
assert(partitions.size == 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
index 33c65d784fba6..3c451e0538721 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala
@@ -35,11 +35,11 @@ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relati
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.v2.DataSourceOptions
-import org.apache.spark.sql.sources.v2.reader.streaming.Offset
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream}
import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach {
@@ -55,7 +55,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
private var serverThread: ServerThread = null
case class AddSocketData(data: String*) extends AddData {
- override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
+ override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = {
require(
query.nonEmpty,
"Cannot add data when there is no query for finding the active socket source")
@@ -176,13 +176,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
test("params not given") {
val provider = new TextSocketSourceProvider
intercept[AnalysisException] {
- provider.getTable(new DataSourceOptions(Map.empty[String, String].asJava))
+ provider.getTable(CaseInsensitiveStringMap.empty())
}
intercept[AnalysisException] {
- provider.getTable(new DataSourceOptions(Map("host" -> "localhost").asJava))
+ provider.getTable(new CaseInsensitiveStringMap(Map("host" -> "localhost").asJava))
}
intercept[AnalysisException] {
- provider.getTable(new DataSourceOptions(Map("port" -> "1234").asJava))
+ provider.getTable(new CaseInsensitiveStringMap(Map("port" -> "1234").asJava))
}
}
@@ -190,7 +190,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
val provider = new TextSocketSourceProvider
val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle")
intercept[AnalysisException] {
- provider.getTable(new DataSourceOptions(params.asJava))
+ provider.getTable(new CaseInsensitiveStringMap(params.asJava))
}
}
@@ -201,10 +201,10 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
StructField("area", StringType) :: Nil)
val params = Map("host" -> "localhost", "port" -> "1234")
val exception = intercept[UnsupportedOperationException] {
- provider.getTable(new DataSourceOptions(params.asJava), userSpecifiedSchema)
+ provider.getTable(new CaseInsensitiveStringMap(params.asJava), userSpecifiedSchema)
}
assert(exception.getMessage.contains(
- "socket source does not support user-specified schema"))
+ "TextSocketSourceProvider source does not support user-specified schema"))
}
test("input row metrics") {
@@ -299,7 +299,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
host = "localhost",
port = serverThread.port,
numPartitions = 2,
- options = DataSourceOptions.empty())
+ options = CaseInsensitiveStringMap.empty())
val partitions = stream.planInputPartitions(stream.initialOffset())
assert(partitions.length == 2)
@@ -351,7 +351,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
host = "localhost",
port = serverThread.port,
numPartitions = 2,
- options = DataSourceOptions.empty())
+ options = CaseInsensitiveStringMap.empty())
stream.startOffset = TextSocketOffset(List(5, 5))
assertThrows[IllegalStateException] {
@@ -367,7 +367,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before
host = "localhost",
port = serverThread.port,
numPartitions = 2,
- options = new DataSourceOptions(Map("includeTimestamp" -> "true").asJava))
+ options = new CaseInsensitiveStringMap(Map("includeTimestamp" -> "true").asJava))
val partitions = stream.planInputPartitions(stream.initialOffset())
assert(partitions.size == 2)
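The changes above replace `DataSourceOptions` with `CaseInsensitiveStringMap` throughout the streaming test sources. As a minimal sketch of why the swap is mechanical, the snippet below exercises the case-insensitive lookup and the typed accessors (`get`, `getInt`, `getBoolean`) that the new options class provides; the object name is illustrative only.

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.util.CaseInsensitiveStringMap

object CaseInsensitiveOptionsDemo {
  def main(args: Array[String]): Unit = {
    // Keys are matched case-insensitively, so sources can read "host", "Host" or "HOST".
    val options = new CaseInsensitiveStringMap(
      Map("Host" -> "localhost", "PORT" -> "1234", "includeTimestamp" -> "true").asJava)

    assert(options.get("host") == "localhost")
    assert(options.getInt("port", 0) == 1234)
    assert(options.getBoolean("INCLUDETIMESTAMP", false))
    assert(CaseInsensitiveStringMap.empty().isEmpty)
  }
}
```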
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala
index 4592a1663faed..60f1b32a41f05 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala
@@ -21,8 +21,8 @@ import org.apache.arrow.vector._
import org.apache.arrow.vector.complex._
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.ArrowColumnVector
import org.apache.spark.unsafe.types.UTF8String
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index e8062dbb91e35..4dd65385d548b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -31,9 +31,9 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.memory.MemoryMode
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.arrow.ArrowUtils
import org.apache.spark.sql.types._
-import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.CalendarInterval
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala
new file mode 100644
index 0000000000000..5b9071b59b9b0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.catalog.v2.Identifier
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{LongType, StringType, StructType}
+
+class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
+
+ import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+
+ private val orc2 = classOf[OrcDataSourceV2].getName
+
+ before {
+ spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)
+ spark.conf.set("spark.sql.default.catalog", "testcat")
+
+ val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data")
+ df.createOrReplaceTempView("source")
+ val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, "f"))).toDF("id", "data")
+ df2.createOrReplaceTempView("source2")
+ }
+
+ after {
+ spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog].clearTables()
+ spark.sql("DROP TABLE source")
+ }
+
+ test("CreateTable: use v2 plan because catalog is set") {
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+ val testCatalog = spark.catalog("testcat").asTableCatalog
+ val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name == "testcat.table_name")
+ assert(table.partitioning.isEmpty)
+ assert(table.properties == Map("provider" -> "foo").asJava)
+ assert(table.schema == new StructType().add("id", LongType).add("data", StringType))
+
+ val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty)
+ }
+
+ test("CreateTable: use v2 plan because provider is v2") {
+ spark.sql(s"CREATE TABLE table_name (id bigint, data string) USING $orc2")
+
+ val testCatalog = spark.catalog("testcat").asTableCatalog
+ val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name == "testcat.table_name")
+ assert(table.partitioning.isEmpty)
+ assert(table.properties == Map("provider" -> orc2).asJava)
+ assert(table.schema == new StructType().add("id", LongType).add("data", StringType))
+
+ val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty)
+ }
+
+ test("CreateTable: fail if table exists") {
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+ val testCatalog = spark.catalog("testcat").asTableCatalog
+
+ val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+ assert(table.name == "testcat.table_name")
+ assert(table.partitioning.isEmpty)
+ assert(table.properties == Map("provider" -> "foo").asJava)
+ assert(table.schema == new StructType().add("id", LongType).add("data", StringType))
+
+ // run a second create query that should fail
+ val exc = intercept[TableAlreadyExistsException] {
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string, id2 bigint) USING bar")
+ }
+
+ assert(exc.getMessage.contains("table_name"))
+
+ // table should not have changed
+ val table2 = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+ assert(table2.name == "testcat.table_name")
+ assert(table2.partitioning.isEmpty)
+ assert(table2.properties == Map("provider" -> "foo").asJava)
+ assert(table2.schema == new StructType().add("id", LongType).add("data", StringType))
+
+ // check that the table is still empty
+ val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty)
+ }
+
+ test("CreateTable: if not exists") {
+ spark.sql(
+ "CREATE TABLE IF NOT EXISTS testcat.table_name (id bigint, data string) USING foo")
+
+ val testCatalog = spark.catalog("testcat").asTableCatalog
+ val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name == "testcat.table_name")
+ assert(table.partitioning.isEmpty)
+ assert(table.properties == Map("provider" -> "foo").asJava)
+ assert(table.schema == new StructType().add("id", LongType).add("data", StringType))
+
+ spark.sql("CREATE TABLE IF NOT EXISTS testcat.table_name (id bigint, data string) USING bar")
+
+ // table should not have changed
+ val table2 = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+ assert(table2.name == "testcat.table_name")
+ assert(table2.partitioning.isEmpty)
+ assert(table2.properties == Map("provider" -> "foo").asJava)
+ assert(table2.schema == new StructType().add("id", LongType).add("data", StringType))
+
+ // check that the table is still empty
+ val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq.empty)
+ }
+
+ test("CreateTable: fail analysis when default catalog is needed but missing") {
+ val originalDefaultCatalog = conf.getConfString("spark.sql.default.catalog")
+ try {
+ conf.unsetConf("spark.sql.default.catalog")
+
+ val exc = intercept[AnalysisException] {
+ spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source")
+ }
+
+ assert(exc.getMessage.contains("No catalog specified for table"))
+ assert(exc.getMessage.contains("table_name"))
+ assert(exc.getMessage.contains("no default catalog is set"))
+
+ } finally {
+ conf.setConfString("spark.sql.default.catalog", originalDefaultCatalog)
+ }
+ }
+
+ test("CreateTableAsSelect: use v2 plan because catalog is set") {
+ spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source")
+
+ val testCatalog = spark.catalog("testcat").asTableCatalog
+ val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name == "testcat.table_name")
+ assert(table.partitioning.isEmpty)
+ assert(table.properties == Map("provider" -> "foo").asJava)
+ assert(table.schema == new StructType()
+ .add("id", LongType, nullable = false)
+ .add("data", StringType))
+
+ val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
+ }
+
+ test("CreateTableAsSelect: use v2 plan because provider is v2") {
+ spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source")
+
+ val testCatalog = spark.catalog("testcat").asTableCatalog
+ val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name == "testcat.table_name")
+ assert(table.partitioning.isEmpty)
+ assert(table.properties == Map("provider" -> orc2).asJava)
+ assert(table.schema == new StructType()
+ .add("id", LongType, nullable = false)
+ .add("data", StringType))
+
+ val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
+ }
+
+ test("CreateTableAsSelect: fail if table exists") {
+ spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source")
+
+ val testCatalog = spark.catalog("testcat").asTableCatalog
+
+ val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+ assert(table.name == "testcat.table_name")
+ assert(table.partitioning.isEmpty)
+ assert(table.properties == Map("provider" -> "foo").asJava)
+ assert(table.schema == new StructType()
+ .add("id", LongType, nullable = false)
+ .add("data", StringType))
+
+ val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
+
+ // run a second CTAS query that should fail
+ val exc = intercept[TableAlreadyExistsException] {
+ spark.sql(
+ "CREATE TABLE testcat.table_name USING bar AS SELECT id, data, id as id2 FROM source2")
+ }
+
+ assert(exc.getMessage.contains("table_name"))
+
+ // table should not have changed
+ val table2 = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+ assert(table2.name == "testcat.table_name")
+ assert(table2.partitioning.isEmpty)
+ assert(table2.properties == Map("provider" -> "foo").asJava)
+ assert(table2.schema == new StructType()
+ .add("id", LongType, nullable = false)
+ .add("data", StringType))
+
+ val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source"))
+ }
+
+ test("CreateTableAsSelect: if not exists") {
+ spark.sql(
+ "CREATE TABLE IF NOT EXISTS testcat.table_name USING foo AS SELECT id, data FROM source")
+
+ val testCatalog = spark.catalog("testcat").asTableCatalog
+ val table = testCatalog.loadTable(Identifier.of(Array(), "table_name"))
+
+ assert(table.name == "testcat.table_name")
+ assert(table.partitioning.isEmpty)
+ assert(table.properties == Map("provider" -> "foo").asJava)
+ assert(table.schema == new StructType()
+ .add("id", LongType, nullable = false)
+ .add("data", StringType))
+
+ val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
+
+ spark.sql(
+ "CREATE TABLE IF NOT EXISTS testcat.table_name USING foo AS SELECT id, data FROM source2")
+
+ // check that the table contains data from just the first CTAS
+ val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
+ checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source"))
+ }
+
+ test("CreateTableAsSelect: fail analysis when default catalog is needed but missing") {
+ val originalDefaultCatalog = conf.getConfString("spark.sql.default.catalog")
+ try {
+ conf.unsetConf("spark.sql.default.catalog")
+
+ val exc = intercept[AnalysisException] {
+ spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source")
+ }
+
+ assert(exc.getMessage.contains("No catalog specified for table"))
+ assert(exc.getMessage.contains("table_name"))
+ assert(exc.getMessage.contains("no default catalog is set"))
+
+ } finally {
+ conf.setConfString("spark.sql.default.catalog", originalDefaultCatalog)
+ }
+ }
+
+ test("DropTable: basic") {
+ val tableName = "testcat.ns1.ns2.tbl"
+ val ident = Identifier.of(Array("ns1", "ns2"), "tbl")
+ sql(s"CREATE TABLE $tableName USING foo AS SELECT id, data FROM source")
+ assert(spark.catalog("testcat").asTableCatalog.tableExists(ident) === true)
+ sql(s"DROP TABLE $tableName")
+ assert(spark.catalog("testcat").asTableCatalog.tableExists(ident) === false)
+ }
+
+ test("DropTable: if exists") {
+ intercept[NoSuchTableException] {
+ sql(s"DROP TABLE testcat.db.notbl")
+ }
+ sql(s"DROP TABLE IF EXISTS testcat.db.notbl")
+ }
+}
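The new DataSourceV2SQLSuite above drives the catalog plugin path end to end: a `TableCatalog` implementation is registered under `spark.sql.catalog.<name>` and multi-part identifiers starting with that name are routed to it. A standalone sketch of the same pattern, reusing the `TestInMemoryTableCatalog` added later in this patch (the object name, table name, and local master are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.sources.v2.TestInMemoryTableCatalog

object CatalogPluginDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("catalog-demo").getOrCreate()

    // Registering a catalog: spark.sql.catalog.<name> points at a TableCatalog implementation.
    spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)

    // Identifiers whose first part names a registered catalog are resolved against it.
    spark.sql("CREATE TABLE testcat.ns1.demo (id bigint, data string) USING foo")
    spark.sql("DROP TABLE testcat.ns1.demo")

    spark.stop()
  }
}
```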
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index 511fdfe5c23ac..379c9c4303cd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -18,21 +18,27 @@
package org.apache.spark.sql.sources.v2
import java.io.File
+import java.util
+import java.util.OptionalLong
+
+import scala.collection.JavaConverters._
import test.org.apache.spark.sql.sources.v2._
import org.apache.spark.SparkException
-import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}
import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.functions._
import org.apache.spark.sql.sources.{Filter, GreaterThan}
+import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.vectorized.ColumnarBatch
class DataSourceV2Suite extends QueryTest with SharedSQLContext {
@@ -182,6 +188,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
}
+ test("statistics report data source") {
+ Seq(classOf[ReportStatisticsDataSource], classOf[JavaReportStatisticsDataSource]).foreach {
+ cls =>
+ withClue(cls.getName) {
+ val df = spark.read.format(cls.getName).load()
+ val logical = df.queryExecution.optimizedPlan.collect {
+ case d: DataSourceV2Relation => d
+ }.head
+
+ val stats = logical.computeStats()
+ assert(stats.rowCount.isDefined && stats.rowCount.get === 10,
+ "Row count statistics should be reported by data source")
+ assert(stats.sizeInBytes === 80,
+ "Size in bytes statistics should be reported by data source")
+ }
+ }
+ }
+
test("SPARK-23574: no shuffle exchange with single partition") {
val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*"))
assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
@@ -195,14 +219,14 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName)
- .option("path", path).save()
+ .option("path", path).mode("append").save()
checkAnswer(
spark.read.format(cls.getName).option("path", path).load(),
spark.range(10).select('id, -'id))
- // test with different save modes
+ // default save mode is append
spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName)
- .option("path", path).mode("append").save()
+ .option("path", path).save()
checkAnswer(
spark.read.format(cls.getName).option("path", path).load(),
spark.range(10).union(spark.range(10)).select('id, -'id))
@@ -213,17 +237,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
spark.read.format(cls.getName).option("path", path).load(),
spark.range(5).select('id, -'id))
- spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName)
- .option("path", path).mode("ignore").save()
- checkAnswer(
- spark.read.format(cls.getName).option("path", path).load(),
- spark.range(5).select('id, -'id))
+ val e = intercept[AnalysisException] {
+ spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName)
+ .option("path", path).mode("ignore").save()
+ }
+ assert(e.message.contains("please use Append or Overwrite modes instead"))
- val e = intercept[Exception] {
+ val e2 = intercept[AnalysisException] {
spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName)
.option("path", path).mode("error").save()
}
- assert(e.getMessage.contains("data already exists"))
+ assert(e2.getMessage.contains("please use Append or Overwrite modes instead"))
// test transaction
val failingUdf = org.apache.spark.sql.functions.udf {
@@ -238,10 +262,10 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
// this input data will fail to read middle way.
val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j)
- val e2 = intercept[SparkException] {
+ val e3 = intercept[SparkException] {
input.write.format(cls.getName).option("path", path).mode("overwrite").save()
}
- assert(e2.getMessage.contains("Writing job aborted"))
+ assert(e3.getMessage.contains("Writing job aborted"))
// make sure we don't have partial data.
assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty)
}
@@ -330,7 +354,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
val options = df.queryExecution.optimizedPlan.collectFirst {
case d: DataSourceV2Relation => d.options
}.get
- assert(options.get(optionName).get == "false")
+ assert(options.get(optionName) === "false")
}
}
@@ -351,19 +375,16 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
}
- test("SPARK-25700: do not read schema when writing in other modes except append mode") {
- withTempPath { file =>
- val cls = classOf[SimpleWriteOnlyDataSource]
- val path = file.getCanonicalPath
- val df = spark.range(5).select('id as 'i, -'id as 'j)
- // non-append mode should not throw exception, as they don't access schema.
- df.write.format(cls.getName).option("path", path).mode("error").save()
- df.write.format(cls.getName).option("path", path).mode("overwrite").save()
- df.write.format(cls.getName).option("path", path).mode("ignore").save()
- // append mode will access schema and should throw exception.
- intercept[SchemaReadAttemptException] {
- df.write.format(cls.getName).option("path", path).mode("append").save()
- }
+ test("SPARK-27411: DataSourceV2Strategy should not eliminate subquery") {
+ withTempView("t1") {
+ val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load()
+ Seq(2, 3).toDF("a").createTempView("t1")
+ val df = t2.where("i < (select max(a) from t1)").select('i)
+ val subqueries = df.queryExecution.executedPlan.collect {
+ case p => p.subqueries
+ }.flatten
+ assert(subqueries.length == 1)
+ checkAnswer(df, (0 until 3).map(i => Row(i)))
}
}
}
@@ -389,11 +410,13 @@ object SimpleReaderFactory extends PartitionReaderFactory {
}
}
-abstract class SimpleBatchTable extends Table with SupportsBatchRead {
+abstract class SimpleBatchTable extends Table with SupportsRead {
override def schema(): StructType = new StructType().add("i", "int").add("j", "int")
override def name(): String = this.getClass.toString
+
+ override def capabilities(): util.Set[TableCapability] = Set(BATCH_READ).asJava
}
abstract class SimpleScanBuilder extends ScanBuilder
@@ -416,8 +439,8 @@ class SimpleSinglePartitionSource extends TableProvider {
}
}
- override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder()
}
}
@@ -433,8 +456,8 @@ class SimpleDataSourceV2 extends TableProvider {
}
}
- override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder()
}
}
@@ -442,8 +465,8 @@ class SimpleDataSourceV2 extends TableProvider {
class AdvancedDataSourceV2 extends TableProvider {
- override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new AdvancedScanBuilder()
}
}
@@ -538,16 +561,16 @@ class SchemaRequiredDataSource extends TableProvider {
override def readSchema(): StructType = schema
}
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
throw new IllegalArgumentException("requires a user-supplied schema")
}
- override def getTable(options: DataSourceOptions, schema: StructType): Table = {
+ override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = {
val userGivenSchema = schema
new SimpleBatchTable {
override def schema(): StructType = userGivenSchema
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder(userGivenSchema)
}
}
@@ -567,8 +590,8 @@ class ColumnarDataSourceV2 extends TableProvider {
}
}
- override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder()
}
}
@@ -619,7 +642,6 @@ object ColumnarReaderFactory extends PartitionReaderFactory {
}
}
-
class PartitionAwareDataSource extends TableProvider {
class MyScanBuilder extends SimpleScanBuilder
@@ -639,8 +661,8 @@ class PartitionAwareDataSource extends TableProvider {
override def outputPartitioning(): Partitioning = new MyPartitioning
}
- override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable {
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder()
}
}
@@ -679,7 +701,7 @@ class SchemaReadAttemptException(m: String) extends RuntimeException(m)
class SimpleWriteOnlyDataSource extends SimpleWritableDataSource {
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
new MyTable(options) {
override def schema(): StructType = {
throw new SchemaReadAttemptException("schema should not be read.")
@@ -687,3 +709,29 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource {
}
}
}
+
+class ReportStatisticsDataSource extends TableProvider {
+
+ class MyScanBuilder extends SimpleScanBuilder
+ with SupportsReportStatistics {
+ override def estimateStatistics(): Statistics = {
+ new Statistics {
+ override def sizeInBytes(): OptionalLong = OptionalLong.of(80)
+
+ override def numRows(): OptionalLong = OptionalLong.of(10)
+ }
+ }
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10))
+ }
+ }
+
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new MyScanBuilder
+ }
+ }
+ }
+}
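The new "statistics report data source" test checks that a scan implementing `SupportsReportStatistics` surfaces its estimates through the optimized plan. A condensed sketch of the same check, assuming the `ReportStatisticsDataSource` defined in this suite (driver object and master are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.sources.v2.ReportStatisticsDataSource

object ReportStatisticsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("stats-demo").getOrCreate()

    val df = spark.read.format(classOf[ReportStatisticsDataSource].getName).load()

    // computeStats() on the v2 relation returns whatever estimateStatistics() reported.
    val relation = df.queryExecution.optimizedPlan.collect {
      case r: DataSourceV2Relation => r
    }.head
    val stats = relation.computeStats()

    assert(stats.rowCount.isDefined && stats.rowCount.get == 10)
    assert(stats.sizeInBytes == 80)

    spark.stop()
  }
}
```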
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala
index f903c17923d0f..0b1e3b5fb076d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala
@@ -33,8 +33,8 @@ class DataSourceV2UtilsSuite extends SparkFunSuite {
conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false")
conf.setConfString("spark.datasource.another.config.name", "123")
conf.setConfString(s"spark.datasource.$keyPrefix.", "123")
- val cs = classOf[DataSourceV2WithSessionConfig].getConstructor().newInstance()
- val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf)
+ val source = new DataSourceV2WithSessionConfig
+ val confs = DataSourceV2Utils.extractSessionConfigs(source, conf)
assert(confs.size == 2)
assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0)
assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala
index fd19a48497fe6..e84c082128e1c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala
@@ -16,15 +16,18 @@
*/
package org.apache.spark.sql.sources.v2
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql.{AnalysisException, QueryTest}
import org.apache.spark.sql.execution.datasources.FileFormat
-import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetTest}
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.reader.ScanBuilder
import org.apache.spark.sql.sources.v2.writer.WriteBuilder
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 {
@@ -32,19 +35,22 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 {
override def shortName(): String = "parquet"
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
new DummyReadOnlyFileTable
}
}
-class DummyReadOnlyFileTable extends Table with SupportsBatchRead {
+class DummyReadOnlyFileTable extends Table with SupportsRead {
override def name(): String = "dummy"
override def schema(): StructType = StructType(Nil)
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
throw new AnalysisException("Dummy file reader")
}
+
+ override def capabilities(): java.util.Set[TableCapability] =
+ Set(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA).asJava
}
class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 {
@@ -53,18 +59,21 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 {
override def shortName(): String = "parquet"
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
new DummyWriteOnlyFileTable
}
}
-class DummyWriteOnlyFileTable extends Table with SupportsBatchWrite {
+class DummyWriteOnlyFileTable extends Table with SupportsWrite {
override def name(): String = "dummy"
override def schema(): StructType = StructType(Nil)
- override def newWriteBuilder(options: DataSourceOptions): WriteBuilder =
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
throw new AnalysisException("Dummy file writer")
+
+ override def capabilities(): java.util.Set[TableCapability] =
+ Set(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava
}
class FileDataSourceV2FallBackSuite extends QueryTest with SharedSQLContext {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
index daca65fd1ad2c..c9d2f1eef24bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.sources.v2
import java.io.{BufferedReader, InputStreamReader, IOException}
+import java.util
import scala.collection.JavaConverters._
@@ -25,12 +26,12 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
-import org.apache.spark.internal.config.SPECULATION_ENABLED
-import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration
/**
@@ -38,8 +39,7 @@ import org.apache.spark.util.SerializableConfiguration
* Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`.
* Each job moves files from `target/_temporary/uniqueId/` to `target`.
*/
-class SimpleWritableDataSource extends DataSourceV2
- with TableProvider with SessionConfigSupport {
+class SimpleWritableDataSource extends TableProvider with SessionConfigSupport {
private val tableSchema = new StructType().add("i", "long").add("j", "long")
@@ -69,38 +69,26 @@ class SimpleWritableDataSource extends DataSourceV2
override def readSchema(): StructType = tableSchema
}
- class MyWriteBuilder(path: String) extends WriteBuilder with SupportsSaveMode {
+ class MyWriteBuilder(path: String) extends WriteBuilder with SupportsTruncate {
private var queryId: String = _
- private var mode: SaveMode = _
+ private var needTruncate = false
override def withQueryId(queryId: String): WriteBuilder = {
this.queryId = queryId
this
}
- override def mode(mode: SaveMode): WriteBuilder = {
- this.mode = mode
+ override def truncate(): WriteBuilder = {
+ this.needTruncate = true
this
}
override def buildForBatch(): BatchWrite = {
- assert(mode != null)
-
val hadoopPath = new Path(path)
val hadoopConf = SparkContext.getActive.get.hadoopConfiguration
val fs = hadoopPath.getFileSystem(hadoopConf)
- if (mode == SaveMode.ErrorIfExists) {
- if (fs.exists(hadoopPath)) {
- throw new RuntimeException("data already exists.")
- }
- }
- if (mode == SaveMode.Ignore) {
- if (fs.exists(hadoopPath)) {
- return null
- }
- }
- if (mode == SaveMode.Overwrite) {
+ if (needTruncate) {
fs.delete(hadoopPath, true)
}
@@ -142,22 +130,27 @@ class SimpleWritableDataSource extends DataSourceV2
}
}
- class MyTable(options: DataSourceOptions) extends SimpleBatchTable with SupportsBatchWrite {
- private val path = options.get("path").get()
+ class MyTable(options: CaseInsensitiveStringMap)
+ extends SimpleBatchTable with SupportsWrite {
+
+ private val path = options.get("path")
private val conf = SparkContext.getActive.get.hadoopConfiguration
override def schema(): StructType = tableSchema
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
new MyScanBuilder(new Path(path).toUri.toString, conf)
}
- override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = {
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
new MyWriteBuilder(path)
}
+
+ override def capabilities(): util.Set[TableCapability] =
+ Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava
}
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
new MyTable(options)
}
}
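With `SupportsSaveMode` removed, `SimpleWritableDataSource` only distinguishes a plain append from a truncate-then-append, so at the DataFrame API level Append and Overwrite succeed while ErrorIfExists and Ignore are rejected during analysis, exactly as the updated DataSourceV2Suite asserts. A sketch of the caller-facing behaviour (the output path and object name are illustrative):

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.sources.v2.SimpleWritableDataSource

object WriteModesDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("write-modes").getOrCreate()

    val cls = classOf[SimpleWritableDataSource].getName
    val path = "/tmp/simple-writable-demo"  // illustrative location
    val df = spark.range(5).selectExpr("id AS i", "-id AS j")

    df.write.format(cls).option("path", path).save()                    // default mode is append
    df.write.format(cls).option("path", path).mode("overwrite").save()  // maps to WriteBuilder.truncate()

    // ErrorIfExists and Ignore are no longer supported on the v2 write path.
    try {
      df.write.format(cls).option("path", path).mode("ignore").save()
    } catch {
      case e: AnalysisException =>
        assert(e.getMessage.contains("please use Append or Overwrite modes instead"))
    }

    spark.stop()
  }
}
```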
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala
new file mode 100644
index 0000000000000..42c2db2539060
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalog.v2.{CatalogV2Implicits, Identifier, TableCatalog, TableChange, TestTableCatalog}
+import org.apache.spark.sql.catalog.v2.expressions.Transform
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
+import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+// this is currently in the spark-sql module because the read and write API is not in catalyst
+// TODO(rdblue): when the v2 source API is in catalyst, merge with TestTableCatalog/InMemoryTable
+class TestInMemoryTableCatalog extends TableCatalog {
+ import CatalogV2Implicits._
+
+ private val tables: util.Map[Identifier, InMemoryTable] =
+ new ConcurrentHashMap[Identifier, InMemoryTable]()
+ private var _name: Option[String] = None
+
+ override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
+ _name = Some(name)
+ }
+
+ override def name: String = _name.get
+
+ override def listTables(namespace: Array[String]): Array[Identifier] = {
+ tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
+ }
+
+ override def loadTable(ident: Identifier): Table = {
+ Option(tables.get(ident)) match {
+ case Some(table) =>
+ table
+ case _ =>
+ throw new NoSuchTableException(ident)
+ }
+ }
+
+ override def createTable(
+ ident: Identifier,
+ schema: StructType,
+ partitions: Array[Transform],
+ properties: util.Map[String, String]): Table = {
+
+ if (tables.containsKey(ident)) {
+ throw new TableAlreadyExistsException(ident)
+ }
+
+ if (partitions.nonEmpty) {
+ throw new UnsupportedOperationException(
+ s"Catalog $name: Partitioned tables are not supported")
+ }
+
+ val table = new InMemoryTable(s"$name.${ident.quoted}", schema, properties)
+
+ tables.put(ident, table)
+
+ table
+ }
+
+ override def alterTable(ident: Identifier, changes: TableChange*): Table = {
+ Option(tables.get(ident)) match {
+ case Some(table) =>
+ val properties = TestTableCatalog.applyPropertiesChanges(table.properties, changes)
+ val schema = TestTableCatalog.applySchemaChanges(table.schema, changes)
+ val newTable = new InMemoryTable(table.name, schema, properties, table.data)
+
+ tables.put(ident, newTable)
+
+ newTable
+ case _ =>
+ throw new NoSuchTableException(ident)
+ }
+ }
+
+ override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined
+
+ def clearTables(): Unit = {
+ tables.clear()
+ }
+}
+
+/**
+ * A simple in-memory table. Rows are stored as a buffered group produced by each output task.
+ */
+private class InMemoryTable(
+ val name: String,
+ val schema: StructType,
+ override val properties: util.Map[String, String])
+ extends Table with SupportsRead with SupportsWrite {
+
+ def this(
+ name: String,
+ schema: StructType,
+ properties: util.Map[String, String],
+ data: Array[BufferedRows]) = {
+ this(name, schema, properties)
+ replaceData(data)
+ }
+
+ def rows: Seq[InternalRow] = data.flatMap(_.rows)
+
+ @volatile var data: Array[BufferedRows] = Array.empty
+
+ def replaceData(buffers: Array[BufferedRows]): Unit = synchronized {
+ data = buffers
+ }
+
+ override def capabilities: util.Set[TableCapability] = Set(
+ TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE).asJava
+
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new ScanBuilder() {
+ def build(): Scan = new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]))
+ }
+ }
+
+ class InMemoryBatchScan(data: Array[InputPartition]) extends Scan with Batch {
+ override def readSchema(): StructType = schema
+
+ override def toBatch: Batch = this
+
+ override def planInputPartitions(): Array[InputPartition] = data
+
+ override def createReaderFactory(): PartitionReaderFactory = BufferedRowsReaderFactory
+ }
+
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
+ new WriteBuilder with SupportsTruncate {
+ private var shouldTruncate: Boolean = false
+
+ override def truncate(): WriteBuilder = {
+ shouldTruncate = true
+ this
+ }
+
+ override def buildForBatch(): BatchWrite = {
+ if (shouldTruncate) TruncateAndAppend else Append
+ }
+ }
+ }
+
+ private object TruncateAndAppend extends BatchWrite {
+ override def createBatchWriterFactory(): DataWriterFactory = {
+ BufferedRowsWriterFactory
+ }
+
+ override def commit(messages: Array[WriterCommitMessage]): Unit = {
+ replaceData(messages.map(_.asInstanceOf[BufferedRows]))
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {
+ }
+ }
+
+ private object Append extends BatchWrite {
+ override def createBatchWriterFactory(): DataWriterFactory = {
+ BufferedRowsWriterFactory
+ }
+
+ override def commit(messages: Array[WriterCommitMessage]): Unit = {
+ replaceData(data ++ messages.map(_.asInstanceOf[BufferedRows]))
+ }
+
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {
+ }
+ }
+}
+
+private class BufferedRows extends WriterCommitMessage with InputPartition with Serializable {
+ val rows = new mutable.ArrayBuffer[InternalRow]()
+}
+
+private object BufferedRowsReaderFactory extends PartitionReaderFactory {
+ override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+ new BufferedRowsReader(partition.asInstanceOf[BufferedRows])
+ }
+}
+
+private class BufferedRowsReader(partition: BufferedRows) extends PartitionReader[InternalRow] {
+ private var index: Int = -1
+
+ override def next(): Boolean = {
+ index += 1
+ index < partition.rows.length
+ }
+
+ override def get(): InternalRow = partition.rows(index)
+
+ override def close(): Unit = {}
+}
+
+private object BufferedRowsWriterFactory extends DataWriterFactory {
+ override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
+ new BufferWriter
+ }
+}
+
+private class BufferWriter extends DataWriter[InternalRow] {
+ private val buffer = new BufferedRows
+
+ override def write(row: InternalRow): Unit = buffer.rows.append(row.copy())
+
+ override def commit(): WriterCommitMessage = buffer
+
+ override def abort(): Unit = {}
+}
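TestInMemoryTableCatalog is exercised through SQL in DataSourceV2SQLSuite, but the `TableCatalog` API can also be driven directly. A sketch of the programmatic lifecycle, initialize, create, alter, drop, using the catalog methods shown above together with `TableChange.setProperty` (identifier, schema, and property values are illustrative):

```scala
import java.util.Collections

import org.apache.spark.sql.catalog.v2.{Identifier, TableChange}
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.sources.v2.TestInMemoryTableCatalog
import org.apache.spark.sql.types.{LongType, StringType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object TableCatalogDemo {
  def main(args: Array[String]): Unit = {
    val catalog = new TestInMemoryTableCatalog
    catalog.initialize("testcat", CaseInsensitiveStringMap.empty())

    val ident = Identifier.of(Array("ns1"), "demo")
    val schema = new StructType().add("id", LongType).add("data", StringType)

    // Create an unpartitioned table, set a property on it, then drop it.
    catalog.createTable(ident, schema, Array.empty[Transform], Collections.emptyMap[String, String]())
    val altered = catalog.alterTable(ident, TableChange.setProperty("owner", "demo"))
    assert(altered.properties.get("owner") == "demo")
    assert(catalog.dropTable(ident))
  }
}
```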
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala
new file mode 100644
index 0000000000000..1d76ee34a0e0b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, OverwriteByExpression, OverwritePartitionsDynamic}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, V2WriteSupportCheck}
+import org.apache.spark.sql.sources.v2.TableCapability._
+import org.apache.spark.sql.types.{LongType, StringType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class V2WriteSupportCheckSuite extends AnalysisTest {
+
+ test("AppendData: check missing capabilities") {
+ val plan = AppendData.byName(
+ DataSourceV2Relation.create(CapabilityTable(), CaseInsensitiveStringMap.empty), TestRelation)
+
+ val exc = intercept[AnalysisException] {
+ V2WriteSupportCheck.apply(plan)
+ }
+
+ assert(exc.getMessage.contains("does not support append in batch mode"))
+ }
+
+ test("AppendData: check correct capabilities") {
+ val plan = AppendData.byName(
+ DataSourceV2Relation.create(CapabilityTable(BATCH_WRITE), CaseInsensitiveStringMap.empty),
+ TestRelation)
+
+ V2WriteSupportCheck.apply(plan)
+ }
+
+ test("Truncate: check missing capabilities") {
+ Seq(CapabilityTable(),
+ CapabilityTable(BATCH_WRITE),
+ CapabilityTable(TRUNCATE),
+ CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>
+
+ val plan = OverwriteByExpression.byName(
+ DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+ Literal(true))
+
+ val exc = intercept[AnalysisException] {
+ V2WriteSupportCheck.apply(plan)
+ }
+
+ assert(exc.getMessage.contains("does not support truncate in batch mode"))
+ }
+ }
+
+ test("Truncate: check correct capabilities") {
+ Seq(CapabilityTable(BATCH_WRITE, TRUNCATE),
+ CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>
+
+ val plan = OverwriteByExpression.byName(
+ DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+ Literal(true))
+
+ V2WriteSupportCheck.apply(plan)
+ }
+ }
+
+ test("OverwriteByExpression: check missing capabilities") {
+ Seq(CapabilityTable(),
+ CapabilityTable(BATCH_WRITE),
+ CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>
+
+ val plan = OverwriteByExpression.byName(
+ DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+ EqualTo(AttributeReference("x", LongType)(), Literal(5)))
+
+ val exc = intercept[AnalysisException] {
+ V2WriteSupportCheck.apply(plan)
+ }
+
+ assert(exc.getMessage.contains(
+ "does not support overwrite expression (`x` = 5) in batch mode"))
+ }
+ }
+
+ test("OverwriteByExpression: check correct capabilities") {
+ val table = CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)
+ val plan = OverwriteByExpression.byName(
+ DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+ EqualTo(AttributeReference("x", LongType)(), Literal(5)))
+
+ V2WriteSupportCheck.apply(plan)
+ }
+
+ test("OverwritePartitionsDynamic: check missing capabilities") {
+ Seq(CapabilityTable(),
+ CapabilityTable(BATCH_WRITE),
+ CapabilityTable(OVERWRITE_DYNAMIC)).foreach { table =>
+
+ val plan = OverwritePartitionsDynamic.byName(
+ DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation)
+
+ val exc = intercept[AnalysisException] {
+ V2WriteSupportCheck.apply(plan)
+ }
+
+ assert(exc.getMessage.contains("does not support dynamic overwrite in batch mode"))
+ }
+ }
+
+ test("OverwritePartitionsDynamic: check correct capabilities") {
+ val table = CapabilityTable(BATCH_WRITE, OVERWRITE_DYNAMIC)
+ val plan = OverwritePartitionsDynamic.byName(
+ DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation)
+
+ V2WriteSupportCheck.apply(plan)
+ }
+}
+
+private object V2WriteSupportCheckSuite {
+ val schema: StructType = new StructType().add("id", LongType).add("data", StringType)
+}
+
+private case object TestRelation extends LeafNode with NamedRelation {
+ override def name: String = "source_relation"
+ override def output: Seq[AttributeReference] = V2WriteSupportCheckSuite.schema.toAttributes
+}
+
+private case class CapabilityTable(_capabilities: TableCapability*) extends Table {
+ override def name(): String = "capability_test_table"
+ override def schema(): StructType = V2WriteSupportCheckSuite.schema
+ override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava
+}
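`V2WriteSupportCheck` decides from `Table.capabilities()` alone whether an `AppendData`, `OverwriteByExpression`, or `OverwritePartitionsDynamic` plan is allowed. As a sketch, a table declaring only `BATCH_WRITE` passes the append check but fails every truncate and overwrite check in the suite above (name and schema are placeholders):

```scala
import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.v2.{Table, TableCapability}
import org.apache.spark.sql.types.StructType

// Advertises plain batch writes only: appends are allowed, while truncate,
// overwrite-by-filter and dynamic overwrite plans are rejected by the check.
class AppendOnlyTable extends Table {
  override def name(): String = "append_only_demo"
  override def schema(): StructType = new StructType().add("id", "long")
  override def capabilities(): util.Set[TableCapability] =
    Set(TableCapability.BATCH_WRITE).asJava
}
```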
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index c696204cecc2c..a0a55c08ff018 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset}
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index ed53def556cb8..619d118e20873 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.streaming
+import java.io.File
import java.util.Locale
import org.apache.hadoop.fs.Path
@@ -454,4 +455,27 @@ class FileStreamSinkSuite extends StreamTest {
}
}
}
+
+ test("special characters in output path") {
+ withTempDir { tempDir =>
+ val checkpointDir = new File(tempDir, "chk")
+ val outputDir = new File(tempDir, "output @#output")
+ val inputData = MemoryStream[Int]
+ inputData.addData(1, 2, 3)
+ val q = inputData.toDF()
+ .writeStream
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .format("parquet")
+ .start(outputDir.getCanonicalPath)
+ try {
+ q.processAllAvailable()
+ } finally {
+ q.stop()
+ }
+ // The "_spark_metadata" directory should be in "outputDir"
+ assert(outputDir.listFiles.map(_.getName).contains(FileStreamSink.metadataDir))
+ val outputDf = spark.read.parquet(outputDir.getCanonicalPath).as[Int]
+ checkDatasetUnorderly(outputDf, 1, 2, 3)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 9235c6d7c896f..0736c6ef00eed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap}
+import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._
import org.apache.spark.sql.streaming.util.StreamManualClock
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 659deb8cbb51e..f229b08a20aa0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -29,7 +29,7 @@ import org.apache.commons.io.FileUtils
import org.apache.hadoop.conf.Configuration
import org.scalatest.time.SpanSugar._
-import org.apache.spark.{SparkConf, SparkContext, TaskContext}
+import org.apache.spark.{SparkConf, SparkContext, TaskContext, TestUtils}
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.Range
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
+import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, MemorySink}
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -876,8 +876,8 @@ class StreamSuite extends StreamTest {
query.awaitTermination()
}
- assert(e.getMessage.contains(providerClassName))
- assert(e.getMessage.contains("instantiated"))
+ TestUtils.assertExceptionMsg(e, providerClassName)
+ TestUtils.assertExceptionMsg(e, "instantiated")
}
}
@@ -1083,15 +1083,15 @@ class StreamSuite extends StreamTest {
test("SPARK-26379 Structured Streaming - Exception on adding current_timestamp " +
" to Dataset - use v2 sink") {
- testCurrentTimestampOnStreamingQuery(useV2Sink = true)
+ testCurrentTimestampOnStreamingQuery()
}
test("SPARK-26379 Structured Streaming - Exception on adding current_timestamp " +
" to Dataset - use v1 sink") {
- testCurrentTimestampOnStreamingQuery(useV2Sink = false)
+ testCurrentTimestampOnStreamingQuery()
}
- private def testCurrentTimestampOnStreamingQuery(useV2Sink: Boolean): Unit = {
+ private def testCurrentTimestampOnStreamingQuery(): Unit = {
val input = MemoryStream[Int]
val df = input.toDS().withColumn("cur_timestamp", lit(current_timestamp()))
@@ -1109,7 +1109,7 @@ class StreamSuite extends StreamTest {
var lastTimestamp = System.currentTimeMillis()
val currentDate = DateTimeUtils.millisToDays(lastTimestamp)
- testStream(df, useV2Sink = useV2Sink) (
+ testStream(df) (
AddData(input, 1),
CheckLastBatch { rows: Seq[Row] =>
lastTimestamp = assertBatchOutputAndUpdateLastTimestamp(rows, lastTimestamp, currentDate, 1)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index da496837e7a19..fc72c940b922a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -42,8 +42,9 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
-import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
+import org.apache.spark.sql.execution.streaming.sources.MemorySink
import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream}
import org.apache.spark.sql.streaming.StreamingQueryListener._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.{Clock, SystemClock, Utils}
@@ -86,7 +87,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
}
protected val defaultTrigger = Trigger.ProcessingTime(0)
- protected val defaultUseV2Sink = false
/** How long to wait for an active stream to catch up when checking a result. */
val streamingTimeout = 10.seconds
@@ -126,7 +126,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
* the active query, and then return the source object the data was added to, as well as the
* offset of the added data.
*/
- def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset)
+ def addData(query: Option[StreamExecution]): (SparkDataStream, OffsetV2)
}
/** A trait that can be extended when testing a source. */
@@ -137,7 +137,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData {
override def toString: String = s"AddData to $source: ${data.mkString(",")}"
- override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
+ override def addData(query: Option[StreamExecution]): (SparkDataStream, OffsetV2) = {
(source, source.addData(data))
}
}
@@ -294,7 +294,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
/** Execute arbitrary code */
object Execute {
def apply(name: String)(func: StreamExecution => Any): AssertOnQuery =
- AssertOnQuery(query => { func(query); true }, "name")
+ AssertOnQuery(query => { func(query); true }, name)
def apply(func: StreamExecution => Any): AssertOnQuery = apply("Execute")(func)
}
@@ -327,8 +327,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
*/
def testStream(
_stream: Dataset[_],
- outputMode: OutputMode = OutputMode.Append,
- useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized {
+ outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized {
import org.apache.spark.sql.streaming.util.StreamManualClock
// `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
@@ -340,8 +339,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
var pos = 0
var currentStream: StreamExecution = null
var lastStream: StreamExecution = null
- val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for
- val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode)
+ val awaiting = new mutable.HashMap[Int, OffsetV2]() // source index -> offset to wait for
+ val sink = new MemorySink
val resetConfValues = mutable.Map[String, Option[String]]()
val defaultCheckpointLocation =
Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -394,10 +393,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
}
def testState = {
- val sinkDebugString = sink match {
- case s: MemorySink => s.toDebugString
- case s: MemorySinkV2 => s.toDebugString
- }
+ val sinkDebugString = sink.toDebugString
+
s"""
|== Progress ==
|$testActions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 97dbb9b0360ec..3f304e9ec7788 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -21,7 +21,7 @@ import java.io.File
import java.util.{Locale, TimeZone}
import org.apache.commons.io.FileUtils
-import org.scalatest.{Assertions, BeforeAndAfterAll}
+import org.scalatest.Assertions
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.rdd.BlockRDD
@@ -32,7 +32,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager}
+import org.apache.spark.sql.execution.streaming.sources.MemorySink
+import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index d00f2e3bf4d1a..5351d9cf7f190 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -180,7 +180,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
val listeners = (1 to 5).map(_ => new EventCollector)
try {
listeners.foreach(listener => spark.streams.addListener(listener))
- testStream(df, OutputMode.Append, useV2Sink = true)(
+ testStream(df, OutputMode.Append)(
StartStream(Trigger.Continuous(1000)),
StopStream,
AssertOnQuery { query =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index dc22e31678fa3..ec0be40528a45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -17,23 +17,26 @@
package org.apache.spark.sql.streaming
+import java.io.File
import java.util.concurrent.CountDownLatch
import scala.collection.mutable
+import org.apache.commons.io.FileUtils
import org.apache.commons.lang3.RandomStringUtils
+import org.apache.hadoop.fs.Path
import org.scalactic.TolerantNumerics
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.mockito.MockitoSugar
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkException, TestUtils}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
+import org.apache.spark.sql.execution.streaming.sources.{MemorySink, TestForeachWriter}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.reader.InputPartition
@@ -495,7 +498,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
test("input row calculation with same V2 source used twice in self-union") {
val streamInput = MemoryStream[Int]
- testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)(
+ testStream(streamInput.toDF().union(streamInput.toDF()))(
AddData(streamInput, 1, 2, 3),
CheckAnswer(1, 1, 2, 2, 3, 3),
AssertOnQuery { q =>
@@ -516,7 +519,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
// relation, which breaks exchange reuse, as the optimizer will remove Project from one side.
// Here we manually add a useful Project, to trigger exchange reuse.
val streamDF = memoryStream.toDF().select('value + 0 as "v")
- testStream(streamDF.join(streamDF, "v"), useV2Sink = true)(
+ testStream(streamDF.join(streamDF, "v"))(
AddData(memoryStream, 1, 2, 3),
CheckAnswer(1, 2, 3),
check
@@ -553,7 +556,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
val streamInput1 = MemoryStream[Int]
val streamInput2 = MemoryStream[Int]
- testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)(
+ testStream(streamInput1.toDF().union(streamInput2.toDF()))(
AddData(streamInput1, 1, 2, 3),
CheckLastBatch(1, 2, 3),
AssertOnQuery { q =>
@@ -584,7 +587,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
val streamInput = MemoryStream[Int]
val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue")
- testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)(
+ testStream(streamInput.toDF().join(staticInputDF, "value"))(
AddData(streamInput, 1, 2, 3),
AssertOnQuery { q =>
q.processAllAvailable()
@@ -606,7 +609,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
val streamInput2 = MemoryStream[Int]
val staticInputDF2 = staticInputDF.union(staticInputDF).cache()
- testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)(
+ testStream(streamInput2.toDF().join(staticInputDF2, "value"))(
AddData(streamInput2, 1, 2, 3),
AssertOnQuery { q =>
q.processAllAvailable()
@@ -714,8 +717,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
q3.processAllAvailable()
}
assert(e.getCause.isInstanceOf[SparkException])
- assert(e.getCause.getCause.isInstanceOf[IllegalStateException])
- assert(e.getMessage.contains("StreamingQuery cannot be used in executors"))
+ assert(e.getCause.getCause.getCause.isInstanceOf[IllegalStateException])
+ TestUtils.assertExceptionMsg(e, "StreamingQuery cannot be used in executors")
} finally {
q1.stop()
q2.stop()
@@ -909,12 +912,195 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation"))
)
- testStream(df, useV2Sink = true)(
+ testStream(df)(
StartStream(trigger = Trigger.Continuous(100)),
AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation"))
)
}
+ test("special characters in checkpoint path") {
+ withTempDir { tempDir =>
+ val checkpointDir = new File(tempDir, "chk @#chk")
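+ // The checkpoint path deliberately contains spaces and special characters.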
+ val inputData = MemoryStream[Int]
+ inputData.addData(1)
+ val q = inputData.toDF()
+ .writeStream
+ .format("noop")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start()
+ try {
+ q.processAllAvailable()
+ assert(checkpointDir.listFiles().toList.nonEmpty)
+ } finally {
+ q.stop()
+ }
+ }
+ }
+
+ /**
+ * Copy the checkpoint generated by Spark 2.4.0 from test resource to `dir` to set up a legacy
+ * streaming checkpoint.
+ */
+ private def setUp2dot4dot0Checkpoint(dir: File): Unit = {
+ val input = getClass.getResource("/structured-streaming/escaped-path-2.4.0")
+ assert(input != null, "cannot find test resource '/structured-streaming/escaped-path-2.4.0'")
+ val inputDir = new File(input.toURI)
+
+ // Copy test files to tempDir so that we won't modify the original data.
+ FileUtils.copyDirectory(inputDir, dir)
+
+ // Spark 2.4 and earlier escaped the _spark_metadata path once
+ val legacySparkMetadataDir = new File(
+ dir,
+ new Path("output %@#output/_spark_metadata").toUri.toString)
+
+ // Fix up the legacy "_spark_metadata" log so the tests below can exercise the migration path.
+ // Ideally we would copy "_spark_metadata" as-is, exactly as a user migrating to the new
+ // version is supposed to do. However, "tempDir" is different in each test run, so we have to
+ // rewrite the absolute path recorded in the metadata to match "tempDir".
+ val sparkMetadata = FileUtils.readFileToString(new File(legacySparkMetadataDir, "0"), "UTF-8")
+ FileUtils.write(
+ new File(legacySparkMetadataDir, "0"),
+ sparkMetadata.replaceAll("TEMPDIR", dir.getCanonicalPath),
+ "UTF-8")
+ }
+
+ test("detect escaped path and report the migration guide") {
+ // Assert that the error message contains the migration conf, path and the legacy path.
+ def assertMigrationError(errorMessage: String, path: File, legacyPath: File): Unit = {
+ Seq(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key,
+ path.getCanonicalPath,
+ legacyPath.getCanonicalPath).foreach { msg =>
+ assert(errorMessage.contains(msg))
+ }
+ }
+
+ withTempDir { tempDir =>
+ setUp2dot4dot0Checkpoint(tempDir)
+
+ // Here are the paths we will use to create the query
+ val outputDir = new File(tempDir, "output %@#output")
+ val checkpointDir = new File(tempDir, "chk %@#chk")
+ val sparkMetadataDir = new File(tempDir, "output %@#output/_spark_metadata")
+
+ // The escaped paths used by Spark 2.4 and earlier.
+ // Spark 2.4 and earlier escaped the checkpoint path three times
+ val legacyCheckpointDir = new File(
+ tempDir,
+ new Path(new Path(new Path("chk %@#chk").toUri.toString).toUri.toString).toUri.toString)
+ // Spark 2.4 and earlier escaped the _spark_metadata path once
+ val legacySparkMetadataDir = new File(
+ tempDir,
+ new Path("output %@#output/_spark_metadata").toUri.toString)
+
+ // Reading a file sink output in a batch query should detect the legacy _spark_metadata
+ // directory and throw an error
+ val e = intercept[SparkException] {
+ spark.read.load(outputDir.getCanonicalPath).as[Int]
+ }
+ assertMigrationError(e.getMessage, sparkMetadataDir, legacySparkMetadataDir)
+
+ // Restarting the streaming query should detect the legacy _spark_metadata directory and throw
+ // an error
+ val inputData = MemoryStream[Int]
+ val e2 = intercept[SparkException] {
+ inputData.toDF()
+ .writeStream
+ .format("parquet")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start(outputDir.getCanonicalPath)
+ }
+ assertMigrationError(e2.getMessage, sparkMetadataDir, legacySparkMetadataDir)
+
+ // Move "_spark_metadata" to fix the file sink and test the checkpoint path.
+ FileUtils.moveDirectory(legacySparkMetadataDir, sparkMetadataDir)
+
+ // Restarting the streaming query should detect the legacy checkpoint path and throw an error
+ val e3 = intercept[SparkException] {
+ inputData.toDF()
+ .writeStream
+ .format("parquet")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start(outputDir.getCanonicalPath)
+ }
+ assertMigrationError(e3.getMessage, checkpointDir, legacyCheckpointDir)
+
+ // Fix the checkpoint path and verify that the user can migrate the issue by moving files.
+ FileUtils.moveDirectory(legacyCheckpointDir, checkpointDir)
+
+ val q = inputData.toDF()
+ .writeStream
+ .format("parquet")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start(outputDir.getCanonicalPath)
+ try {
+ q.processAllAvailable()
+ // Check the query id to make sure the restarted query did use the existing checkpoint.
+ assert(q.id.toString == "09be7fb3-49d8-48a6-840d-e9c2ad92a898")
+
+ // Verify that the batch query can read "_spark_metadata" correctly after migration.
+ val df = spark.read.load(outputDir.getCanonicalPath)
+ assert(df.queryExecution.executedPlan.toString contains "MetadataLogFileIndex")
+ checkDatasetUnorderly(df.as[Int], 1, 2, 3)
+ } finally {
+ q.stop()
+ }
+ }
+ }
+
+ test("ignore the escaped path check when the flag is off") {
+ withTempDir { tempDir =>
+ setUp2dot4dot0Checkpoint(tempDir)
+ val outputDir = new File(tempDir, "output %@#output")
+ val checkpointDir = new File(tempDir, "chk %@#chk")
+
+ withSQLConf(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key -> "false") {
+ // Verify that the batch query ignores the legacy "_spark_metadata"
+ val df = spark.read.load(outputDir.getCanonicalPath)
+ assert(!(df.queryExecution.executedPlan.toString contains "MetadataLogFileIndex"))
+ checkDatasetUnorderly(df.as[Int], 1, 2, 3)
+
+ val inputData = MemoryStream[Int]
+ val q = inputData.toDF()
+ .writeStream
+ .format("parquet")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .start(outputDir.getCanonicalPath)
+ try {
+ q.processAllAvailable()
+ // Check the query id to make sure it ignores the legacy checkpoint
+ assert(q.id.toString != "09be7fb3-49d8-48a6-840d-e9c2ad92a898")
+ } finally {
+ q.stop()
+ }
+ }
+ }
+ }
+
+ test("containsSpecialCharsInPath") {
+ Seq("foo/b ar",
+ "/foo/b ar",
+ "file:/foo/b ar",
+ "file://foo/b ar",
+ "file:///foo/b ar",
+ "file://foo:bar@bar/foo/b ar").foreach { p =>
+ assert(StreamExecution.containsSpecialCharsInPath(new Path(p)), s"failed to check $p")
+ }
+ Seq("foo/bar",
+ "/foo/bar",
+ "file:/foo/bar",
+ "file://foo/bar",
+ "file:///foo/bar",
+ "file://foo:bar@bar/foo/bar",
+ // Special chars outside the path component should not be flagged, since such URLs won't hit
+ // the escaped path issue.
+ "file://foo:b ar@bar/foo/bar",
+ "file://foo:bar@b ar/foo/bar",
+ "file://f oo:bar@bar/foo/bar").foreach { p =>
+ assert(!StreamExecution.containsSpecialCharsInPath(new Path(p)), s"failed to check $p")
+ }
+ }
+
/** Create a streaming DF that only executes one batch in which it returns the given static DF */
private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = {
require(!triggerDF.isStreaming)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala
index 10bea7f090571..59d6ac0af52a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala
@@ -34,7 +34,7 @@ class ContinuousQueryStatusAndProgressSuite extends ContinuousSuiteBase {
}
val trigger = Trigger.Continuous(100)
- testStream(input.toDF(), useV2Sink = true)(
+ testStream(input.toDF())(
StartStream(trigger),
Execute(assertStatus),
AddData(input, 0, 1, 2),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
index d3d210c02e90d..bad22590807a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousStream, PartitionOffset}
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
@@ -43,7 +43,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
override def beforeEach(): Unit = {
super.beforeEach()
epochEndpoint = EpochCoordinatorRef.create(
- mock[StreamingWriteSupport],
+ mock[StreamingWrite],
mock[ContinuousStream],
mock[ContinuousExecution],
coordinatorId,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index 344a8aa55f0c5..9840c7f066780 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.test.TestSparkSession
@@ -56,7 +57,6 @@ class ContinuousSuiteBase extends StreamTest {
protected val longContinuousTrigger = Trigger.Continuous("1 hour")
override protected val defaultTrigger = Trigger.Continuous(100)
- override protected val defaultUseV2Sink = true
}
class ContinuousSuite extends ContinuousSuiteBase {
@@ -238,7 +238,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
.load()
.select('value)
- testStream(df, useV2Sink = true)(
+ testStream(df)(
StartStream(longContinuousTrigger),
AwaitEpoch(0),
Execute(waitForRateSourceTriggers(_, 10)),
@@ -256,7 +256,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
.load()
.select('value)
- testStream(df, useV2Sink = true)(
+ testStream(df)(
StartStream(Trigger.Continuous(2012)),
AwaitEpoch(0),
Execute(waitForRateSourceTriggers(_, 10)),
@@ -273,7 +273,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
.load()
.select('value)
- testStream(df, useV2Sink = true)(
+ testStream(df)(
StartStream(Trigger.Continuous(1012)),
AwaitEpoch(2),
StopStream,
@@ -343,3 +343,33 @@ class ContinuousMetaSuite extends ContinuousSuiteBase {
}
}
}
+
+class ContinuousEpochBacklogSuite extends ContinuousSuiteBase {
+ import testImplicits._
+
+ override protected def createSparkSession = new TestSparkSession(
+ new SparkContext(
+ "local[1]",
+ "continuous-stream-test-sql-context",
+ sparkConf.set("spark.sql.testkey", "true")))
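+ // With local[1] there is only one task slot, so the two-partition continuous query below stalls.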
+
+ // This test forces the backlog to overflow by not standing up enough executors for the query
+ // to make progress.
+ test("epoch backlog overflow") {
+ withSQLConf((CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE.key, "10")) {
+ val df = spark.readStream
+ .format("rate")
+ .option("numPartitions", "2")
+ .option("rowsPerSecond", "500")
+ .load()
+ .select('value)
+
+ testStream(df)(
+ StartStream(Trigger.Continuous(1)),
+ ExpectFailure[IllegalStateException] { e =>
+ assert(e.getMessage.contains("queue has exceeded its maximum"))
+ }
+ )
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
index a0b56ec17f0be..e3498db4194e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala
@@ -17,9 +17,9 @@
package org.apache.spark.sql.streaming.continuous
+import org.mockito.{ArgumentCaptor, InOrder}
import org.mockito.ArgumentMatchers.{any, eq => eqTo}
-import org.mockito.InOrder
-import org.mockito.Mockito.{inOrder, never, verify}
+import org.mockito.Mockito._
import org.scalatest.BeforeAndAfterEach
import org.scalatest.mockito.MockitoSugar
@@ -27,9 +27,10 @@ import org.apache.spark._
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.LocalSparkSession
import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset}
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite
import org.apache.spark.sql.test.TestSparkSession
class EpochCoordinatorSuite
@@ -40,17 +41,22 @@ class EpochCoordinatorSuite
private var epochCoordinator: RpcEndpointRef = _
- private var writeSupport: StreamingWriteSupport = _
+ private var writeSupport: StreamingWrite = _
private var query: ContinuousExecution = _
private var orderVerifier: InOrder = _
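+ // Use a small backlog queue so the overflow tests below can reach the limit quickly.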
+ private val epochBacklogQueueSize = 10
override def beforeEach(): Unit = {
val stream = mock[ContinuousStream]
- writeSupport = mock[StreamingWriteSupport]
+ writeSupport = mock[StreamingWrite]
query = mock[ContinuousExecution]
orderVerifier = inOrder(writeSupport, query)
- spark = new TestSparkSession()
+ spark = new TestSparkSession(
+ new SparkContext(
+ "local[2]", "test-sql-context",
+ new SparkConf().set("spark.sql.testkey", "true")
+ .set(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE, epochBacklogQueueSize)))
epochCoordinator
= EpochCoordinatorRef.create(writeSupport, stream, query, "test", 1, spark, SparkEnv.get)
@@ -186,6 +192,66 @@ class EpochCoordinatorSuite
verifyCommitsInOrderOf(List(1, 2, 3, 4, 5))
}
+ test("several epochs, max epoch backlog reached by partitionOffsets") {
+ setWriterPartitions(1)
+ setReaderPartitions(1)
+
+ reportPartitionOffset(0, 1)
+ // Commit messages never arrive.
+ for (i <- 2 to epochBacklogQueueSize + 1) {
+ reportPartitionOffset(0, i)
+ }
+
+ makeSynchronousCall()
+
+ for (i <- 1 to epochBacklogQueueSize + 1) {
+ verifyNoCommitFor(i)
+ }
+ verifyStoppedWithException("Size of the partition offset queue has exceeded its maximum")
+ }
+
+ test("several epochs, max epoch backlog reached by partitionCommits") {
+ setWriterPartitions(1)
+ setReaderPartitions(1)
+
+ commitPartitionEpoch(0, 1)
+ // Offset messages never arrive.
+ for (i <- 2 to epochBacklogQueueSize + 1) {
+ commitPartitionEpoch(0, i)
+ }
+
+ makeSynchronousCall()
+
+ for (i <- 1 to epochBacklogQueueSize + 1) {
+ verifyNoCommitFor(i)
+ }
+ verifyStoppedWithException("Size of the partition commit queue has exceeded its maximum")
+ }
+
+ test("several epochs, max epoch backlog reached by epochsWaitingToBeCommitted") {
+ setWriterPartitions(2)
+ setReaderPartitions(2)
+
+ commitPartitionEpoch(0, 1)
+ reportPartitionOffset(0, 1)
+
+ // Messages for the second partition in epoch 1 never arrive.
+ // The upper bound is +2 because the first epoch has not completed yet.
+ for (i <- 2 to epochBacklogQueueSize + 2) {
+ commitPartitionEpoch(0, i)
+ reportPartitionOffset(0, i)
+ commitPartitionEpoch(1, i)
+ reportPartitionOffset(1, i)
+ }
+
+ makeSynchronousCall()
+
+ for (i <- 1 to epochBacklogQueueSize + 2) {
+ verifyNoCommitFor(i)
+ }
+ verifyStoppedWithException("Size of the epoch queue has exceeded its maximum")
+ }
+
private def setWriterPartitions(numPartitions: Int): Unit = {
epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions))
}
@@ -221,4 +287,13 @@ class EpochCoordinatorSuite
private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = {
epochs.foreach(verifyCommit)
}
+
+ private def verifyStoppedWithException(msg: String): Unit = {
+ val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable])
+ verify(query, atLeastOnce()).stopInNewThread(exceptionCaptor.capture())
+
+ import scala.collection.JavaConverters._
+ val throwable = exceptionCaptor.getAllValues.asScala.find(_.getMessage === msg)
+ assert(throwable.isDefined, "Stream stopped with an exception but expected message is missing")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
index 62f166602941c..7b2c1a56e8baa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala
@@ -17,6 +17,11 @@
package org.apache.spark.sql.streaming.sources
+import java.util
+import java.util.Collections
+
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper}
@@ -24,11 +29,14 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
+import org.apache.spark.sql.sources.v2.writer.{WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils
class FakeDataStream extends MicroBatchStream with ContinuousStream {
@@ -59,26 +67,27 @@ class FakeScanBuilder extends ScanBuilder with Scan {
override def toContinuousStream(checkpointLocation: String): ContinuousStream = new FakeDataStream
}
-trait FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead {
- override def name(): String = "fake"
- override def schema(): StructType = StructType(Seq())
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder
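+// A write builder shared by the fake sinks below: it declares streaming write support but throws
+// if the query ever tries to write, since these tests never produce real output.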
+class FakeWriteBuilder extends WriteBuilder with StreamingWrite {
+ override def buildForStreaming(): StreamingWrite = this
+ override def createStreamingWriterFactory(): StreamingDataWriterFactory = {
+ throw new IllegalStateException("fake sink - cannot actually write")
+ }
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+ throw new IllegalStateException("fake sink - cannot actually write")
+ }
+ override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+ throw new IllegalStateException("fake sink - cannot actually write")
+ }
}
-trait FakeContinuousReadTable extends Table with SupportsContinuousRead {
+trait FakeStreamingWriteTable extends Table with SupportsWrite {
override def name(): String = "fake"
override def schema(): StructType = StructType(Seq())
- override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder
-}
-
-trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider {
- override def createStreamingWriteSupport(
- queryId: String,
- schema: StructType,
- mode: OutputMode,
- options: DataSourceOptions): StreamingWriteSupport = {
- LastWriteOptions.options = options
- throw new IllegalStateException("fake sink - cannot actually write")
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(STREAMING_WRITE).asJava
+ }
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = {
+ new FakeWriteBuilder
}
}
@@ -90,9 +99,18 @@ class FakeReadMicroBatchOnly
override def keyPrefix: String = shortName()
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
LastReadOptions.options = options
- new FakeMicroBatchReadTable {}
+ new Table with SupportsRead {
+ override def name(): String = "fake"
+ override def schema(): StructType = StructType(Seq())
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(MICRO_BATCH_READ).asJava
+ }
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new FakeScanBuilder
+ }
+ }
}
}
@@ -104,45 +122,78 @@ class FakeReadContinuousOnly
override def keyPrefix: String = shortName()
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
LastReadOptions.options = options
- new FakeContinuousReadTable {}
+ new Table with SupportsRead {
+ override def name(): String = "fake"
+ override def schema(): StructType = StructType(Seq())
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(CONTINUOUS_READ).asJava
+ }
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new FakeScanBuilder
+ }
+ }
}
}
class FakeReadBothModes extends DataSourceRegister with TableProvider {
override def shortName(): String = "fake-read-microbatch-continuous"
- override def getTable(options: DataSourceOptions): Table = {
- new Table with FakeMicroBatchReadTable with FakeContinuousReadTable {}
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ new Table with SupportsRead {
+ override def name(): String = "fake"
+ override def schema(): StructType = StructType(Seq())
+ override def capabilities(): util.Set[TableCapability] = {
+ Set(MICRO_BATCH_READ, CONTINUOUS_READ).asJava
+ }
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new FakeScanBuilder
+ }
+ }
}
}
class FakeReadNeitherMode extends DataSourceRegister with TableProvider {
override def shortName(): String = "fake-read-neither-mode"
- override def getTable(options: DataSourceOptions): Table = {
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
new Table {
override def name(): String = "fake"
override def schema(): StructType = StructType(Nil)
+ override def capabilities(): util.Set[TableCapability] = Collections.emptySet()
}
}
}
-class FakeWriteSupportProvider
+class FakeWriteOnly
extends DataSourceRegister
- with FakeStreamingWriteSupportProvider
+ with TableProvider
with SessionConfigSupport {
override def shortName(): String = "fake-write-microbatch-continuous"
override def keyPrefix: String = shortName()
+
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ LastWriteOptions.options = options
+ new Table with FakeStreamingWriteTable {
+ override def name(): String = "fake"
+ override def schema(): StructType = StructType(Nil)
+ }
+ }
}
-class FakeNoWrite extends DataSourceRegister {
+class FakeNoWrite extends DataSourceRegister with TableProvider {
override def shortName(): String = "fake-write-neither-mode"
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ new Table {
+ override def name(): String = "fake"
+ override def schema(): StructType = StructType(Nil)
+ override def capabilities(): util.Set[TableCapability] = Collections.emptySet()
+ }
+ }
}
-
case class FakeWriteV1FallbackException() extends Exception
class FakeSink extends Sink {
@@ -150,21 +201,28 @@ class FakeSink extends Sink {
}
class FakeWriteSupportProviderV1Fallback extends DataSourceRegister
- with FakeStreamingWriteSupportProvider with StreamSinkProvider {
+ with TableProvider with StreamSinkProvider {
override def createSink(
- sqlContext: SQLContext,
- parameters: Map[String, String],
- partitionColumns: Seq[String],
- outputMode: OutputMode): Sink = {
+ sqlContext: SQLContext,
+ parameters: Map[String, String],
+ partitionColumns: Seq[String],
+ outputMode: OutputMode): Sink = {
new FakeSink()
}
override def shortName(): String = "fake-write-v1-fallback"
+
+ override def getTable(options: CaseInsensitiveStringMap): Table = {
+ new Table with FakeStreamingWriteTable {
+ override def name(): String = "fake"
+ override def schema(): StructType = StructType(Nil)
+ }
+ }
}
object LastReadOptions {
- var options: DataSourceOptions = _
+ var options: CaseInsensitiveStringMap = _
def clear(): Unit = {
options = null
@@ -172,7 +230,7 @@ object LastReadOptions {
}
object LastWriteOptions {
- var options: DataSourceOptions = _
+ var options: CaseInsensitiveStringMap = _
def clear(): Unit = {
options = null
@@ -260,7 +318,7 @@ class StreamingDataSourceV2Suite extends StreamTest {
testPositiveCaseWithQuery(
"fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v2Query =>
assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink
- .isInstanceOf[FakeWriteSupportProviderV1Fallback])
+ .isInstanceOf[Table])
}
// Ensure we create a V1 sink with the config. Note the config is a comma separated
@@ -289,8 +347,8 @@ class StreamingDataSourceV2Suite extends StreamTest {
testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ =>
eventually(timeout(streamingTimeout)) {
// Write options should not be set.
- assert(LastWriteOptions.options.getBoolean(readOptionName, false) == false)
- assert(LastReadOptions.options.getBoolean(readOptionName, false) == true)
+ assert(!LastWriteOptions.options.containsKey(readOptionName))
+ assert(LastReadOptions.options.getBoolean(readOptionName, false))
}
}
}
@@ -300,8 +358,8 @@ class StreamingDataSourceV2Suite extends StreamTest {
testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ =>
eventually(timeout(streamingTimeout)) {
// Read options should not be set.
- assert(LastReadOptions.options.getBoolean(writeOptionName, false) == false)
- assert(LastWriteOptions.options.getBoolean(writeOptionName, false) == true)
+ assert(!LastReadOptions.options.containsKey(writeOptionName))
+ assert(LastWriteOptions.options.getBoolean(writeOptionName, false))
}
}
}
@@ -319,44 +377,43 @@ class StreamingDataSourceV2Suite extends StreamTest {
for ((read, write, trigger) <- cases) {
testQuietly(s"stream with read format $read, write format $write, trigger $trigger") {
- val table = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor()
- .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty())
- val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf).
- getConstructor().newInstance()
-
- (table, writeSource, trigger) match {
- // Valid microbatch queries.
- case (_: SupportsMicroBatchRead, _: StreamingWriteSupportProvider, t)
- if !t.isInstanceOf[ContinuousTrigger] =>
- testPositiveCase(read, write, trigger)
-
- // Valid continuous queries.
- case (_: SupportsContinuousRead, _: StreamingWriteSupportProvider,
- _: ContinuousTrigger) =>
- testPositiveCase(read, write, trigger)
+ val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor()
+ .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty())
+
+ val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor()
+ .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty())
+ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
+ trigger match {
// Invalid - can't read at all
- case (r, _, _) if !r.isInstanceOf[SupportsMicroBatchRead] &&
- !r.isInstanceOf[SupportsContinuousRead] =>
+ case _ if !sourceTable.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) =>
testNegativeCase(read, write, trigger,
s"Data source $read does not support streamed reading")
// Invalid - can't write
- case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] =>
+ case _ if !sinkTable.supports(STREAMING_WRITE) =>
testNegativeCase(read, write, trigger,
s"Data source $write does not support streamed writing")
- // Invalid - trigger is continuous but reader is not
- case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger)
- if !r.isInstanceOf[SupportsContinuousRead] =>
- testNegativeCase(read, write, trigger,
- s"Data source $read does not support continuous processing")
+ case _: ContinuousTrigger =>
+ if (sourceTable.supports(CONTINUOUS_READ)) {
+ // Valid continuous queries.
+ testPositiveCase(read, write, trigger)
+ } else {
+ // Invalid - trigger is continuous but reader is not
+ testNegativeCase(
+ read, write, trigger, s"Data source $read does not support continuous processing")
+ }
- // Invalid - trigger is microbatch but reader is not
- case (r, _, t) if !r.isInstanceOf[SupportsMicroBatchRead] &&
- !t.isInstanceOf[ContinuousTrigger] =>
- testPostCreationNegativeCase(read, write, trigger,
- s"Data source $read does not support microbatch processing")
+ case microBatchTrigger =>
+ if (sourceTable.supports(MICRO_BATCH_READ)) {
+ // Valid continuous queries.
+ testPositiveCase(read, write, trigger)
+ } else {
+ // Invalid - trigger is microbatch but reader is not
+ testPostCreationNegativeCase(read, write, trigger,
+ s"Data source $read does not support microbatch processing")
+ }
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
index 74ea0bfacba54..99dc0769a3d69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala
@@ -359,7 +359,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
test("source metadataPath") {
LastOptions.clear()
- val checkpointLocationURI = new Path(newMetadataDir).toUri
+ val checkpointLocation = new Path(newMetadataDir)
val df1 = spark.readStream
.format("org.apache.spark.sql.streaming.test")
@@ -371,7 +371,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
val q = df1.union(df2).writeStream
.format("org.apache.spark.sql.streaming.test")
- .option("checkpointLocation", checkpointLocationURI.toString)
+ .option("checkpointLocation", checkpointLocation.toString)
.trigger(ProcessingTime(10.seconds))
.start()
q.processAllAvailable()
@@ -379,14 +379,14 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
verify(LastOptions.mockStreamSourceProvider).createSource(
any(),
- meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/0"),
+ meq(s"${new Path(makeQualifiedPath(checkpointLocation.toString)).toString}/sources/0"),
meq(None),
meq("org.apache.spark.sql.streaming.test"),
meq(Map.empty))
verify(LastOptions.mockStreamSourceProvider).createSource(
any(),
- meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/1"),
+ meq(s"${new Path(makeQualifiedPath(checkpointLocation.toString)).toString}/sources/1"),
meq(None),
meq("org.apache.spark.sql.streaming.test"),
meq(Map.empty))
@@ -614,6 +614,21 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
}
}
+ test("configured checkpoint dir should not be deleted if a query is stopped without errors and" +
+ " force temp checkpoint deletion enabled") {
+ import testImplicits._
+ withTempDir { checkpointPath =>
+ withSQLConf(SQLConf.CHECKPOINT_LOCATION.key -> checkpointPath.getAbsolutePath,
+ SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.key -> "true") {
+ val ds = MemoryStream[Int].toDS
+ val query = ds.writeStream.format("console").start()
+ assert(checkpointPath.exists())
+ query.stop()
+ assert(checkpointPath.exists())
+ }
+ }
+ }
+
test("temp checkpoint dir should be deleted if a query is stopped without errors") {
import testImplicits._
val query = MemoryStream[Int].toDS.writeStream.format("console").start()
@@ -627,6 +642,17 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
}
testQuietly("temp checkpoint dir should not be deleted if a query is stopped with an error") {
+ testTempCheckpointWithFailedQuery(false)
+ }
+
+ testQuietly("temp checkpoint should be deleted if a query is stopped with an error and force" +
+ " temp checkpoint deletion enabled") {
+ withSQLConf(SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.key -> "true") {
+ testTempCheckpointWithFailedQuery(true)
+ }
+ }
+
+ private def testTempCheckpointWithFailedQuery(checkpointMustBeDeleted: Boolean): Unit = {
import testImplicits._
val input = MemoryStream[Int]
val query = input.toDS.map(_ / 0).writeStream.format("console").start()
@@ -638,7 +664,11 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter {
intercept[StreamingQueryException] {
query.awaitTermination()
}
- assert(fs.exists(checkpointDir))
+ if (!checkpointMustBeDeleted) {
+ assert(fs.exists(checkpointDir))
+ } else {
+ assert(!fs.exists(checkpointDir))
+ }
}
test("SPARK-20431: Specify a schema by using a DDL-formatted string") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index e45ab19aadbfa..a388de1970f14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -38,10 +38,15 @@ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.noop.NoopDataSource
import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.QueryExecutionListener
import org.apache.spark.util.Utils
@@ -220,15 +225,75 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
}
test("save mode") {
- val df = spark.read
+ spark.range(10).write
.format("org.apache.spark.sql.test")
- .load()
+ .mode(SaveMode.ErrorIfExists)
+ .save()
+ assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
- df.write
+ spark.range(10).write
+ .format("org.apache.spark.sql.test")
+ .mode(SaveMode.Append)
+ .save()
+ assert(LastOptions.saveMode === SaveMode.Append)
+
+ // By default the save mode is `ErrorIfExists` for data source v1.
+ spark.range(10).write
.format("org.apache.spark.sql.test")
- .mode(SaveMode.ErrorIfExists)
.save()
assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
+
+ spark.range(10).write
+ .format("org.apache.spark.sql.test")
+ .mode("default")
+ .save()
+ assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
+ }
+
+ test("save mode for data source v2") {
+ var plan: LogicalPlan = null
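+ // Capture the analyzed plan of each write with a QueryExecutionListener so the test can check
+ // which v2 write node (AppendData or OverwriteByExpression) was planned for each save mode.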
+ val listener = new QueryExecutionListener {
+ override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+ plan = qe.analyzed
+ }
+ override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
+ }
+
+ spark.listenerManager.register(listener)
+ try {
+ // append mode creates `AppendData`
+ spark.range(10).write
+ .format(classOf[NoopDataSource].getName)
+ .mode(SaveMode.Append)
+ .save()
+ sparkContext.listenerBus.waitUntilEmpty(1000)
+ assert(plan.isInstanceOf[AppendData])
+
+ // overwrite mode creates `OverwriteByExpression`
+ spark.range(10).write
+ .format(classOf[NoopDataSource].getName)
+ .mode(SaveMode.Overwrite)
+ .save()
+ sparkContext.listenerBus.waitUntilEmpty(1000)
+ assert(plan.isInstanceOf[OverwriteByExpression])
+
+ // By default the save mode is `ErrorIfExists` for data source v2.
+ spark.range(10).write
+ .format(classOf[NoopDataSource].getName)
+ .save()
+ sparkContext.listenerBus.waitUntilEmpty(1000)
+ assert(plan.isInstanceOf[AppendData])
+
+ spark.range(10).write
+ .format(classOf[NoopDataSource].getName)
+ .mode("default")
+ .save()
+ sparkContext.listenerBus.waitUntilEmpty(1000)
+ assert(plan.isInstanceOf[AppendData])
+ } finally {
+ spark.listenerManager.unregister(listener)
+ }
}
test("test path option in load") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 132b0e4db0d71..84e5fae79bf16 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlanner
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck}
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
@@ -72,6 +73,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
new FindDataSourceTable(session) +:
new ResolveSQLOnFile(session) +:
new FallbackOrcDataSourceV2(session) +:
+ DataSourceResolution(conf, session.catalog(_)) +:
customResolutionRules
override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
@@ -86,6 +88,8 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
override val extendedCheckRules: Seq[LogicalPlan => Unit] =
PreWriteCheck +:
PreReadCheck +:
+ V2WriteSupportCheck +:
+ V2StreamingScanSupportCheck +:
customCheckRules
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 66426824573c6..a4587abbf389d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -172,7 +172,7 @@ abstract class HiveComparisonTest
// and does not return it as a query answer.
case _: SetCommand => Seq("0")
case _: ExplainCommand => answer
- case _: DescribeTableCommand | ShowColumnsCommand(_, _) =>
+ case _: DescribeCommandBase | ShowColumnsCommand(_, _) =>
// Filter out non-deterministic lines and lines which do not have actual results but
// can introduce problems because of the way Hive formats these lines.
// Then, remove empty lines. Do not sort the results.
@@ -375,7 +375,7 @@ abstract class HiveComparisonTest
if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) &&
(!hiveQuery.logical.isInstanceOf[ShowFunctionsCommand]) &&
(!hiveQuery.logical.isInstanceOf[DescribeFunctionCommand]) &&
- (!hiveQuery.logical.isInstanceOf[DescribeTableCommand]) &&
+ (!hiveQuery.logical.isInstanceOf[DescribeCommandBase]) &&
preparedHive != catalyst) {
val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive