diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java new file mode 100644 index 000000000000..bbb01597bc96 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/ClusteredDistribution.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.distributions; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.expressions.Expression; + +/** + * A distribution where tuples that share the same values for clustering expressions are co-located + * in the same partition. + */ +@Experimental +public interface ClusteredDistribution extends Distribution { + /** + * Returns clustering expressions. + */ + Expression[] clustering(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java new file mode 100644 index 000000000000..e7c6446ce700 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distribution.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.distributions; + +import org.apache.spark.annotation.Experimental; + +/** + * An interface that defines how data is distributed across partitions. 
+ */ +@Experimental +public interface Distribution {} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java new file mode 100644 index 000000000000..3b8cd79c23bb --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/Distributions.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.distributions; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.SortOrder; + +/** + * Helper methods to create distributions to pass into Spark. + */ +@Experimental +public class Distributions { + private Distributions() { + } + + /** + * Creates a distribution where no promises are made about co-location of data. + */ + public static UnspecifiedDistribution unspecified() { + return LogicalDistributions.unspecified(); + } + + /** + * Creates a distribution where tuples that share the same values for clustering expressions are + * co-located in the same partition. + */ + public static ClusteredDistribution clustered(Expression[] clustering) { + return LogicalDistributions.clustered(clustering); + } + + /** + * Creates a distribution where tuples have been ordered across partitions according + * to ordering expressions, but not necessarily within a given partition. + */ + public static OrderedDistribution ordered(SortOrder[] ordering) { + return LogicalDistributions.ordered(ordering); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java new file mode 100644 index 000000000000..2c63e833bcb8 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/OrderedDistribution.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.distributions; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.expressions.SortOrder; + +/** + * A distribution where tuples have been ordered across partitions according + * to ordering expressions, but not necessarily within a given partition. + */ +@Experimental +public interface OrderedDistribution extends Distribution { + /** + * Returns ordering expressions. + */ + SortOrder[] ordering(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java new file mode 100644 index 000000000000..cfd1e4a39feb --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/distributions/UnspecifiedDistribution.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.distributions; + +import org.apache.spark.annotation.Experimental; + +/** + * A distribution where no promises are made about co-location of data. + */ +@Experimental +public interface UnspecifiedDistribution extends Distribution {} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java index 791dc969ab00..984de6258f84 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expressions.java @@ -164,4 +164,15 @@ public static Transform hours(String column) { return LogicalExpressions.hours(Expressions.column(column)); } + /** + * Create a sort expression. + * + * @param expr an expression to produce values to sort + * @param direction direction of the sort + * @param nullOrder null order of the sort + * @return a SortOrder + */ + public static SortOrder sort(Expression expr, SortDirection direction, NullOrdering nullOrder) { + return LogicalExpressions.sort(expr, direction, nullOrder); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java new file mode 100644 index 000000000000..3d22ce1b8929 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NullOrdering.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Experimental; + +/** + * A null order used in sorting expressions. + */ +@Experimental +public enum NullOrdering { + NULLS_FIRST, NULLS_LAST; + + @Override + public String toString() { + switch (this) { + case NULLS_FIRST: + return "NULLS FIRST"; + case NULLS_LAST: + return "NULLS LAST"; + default: + throw new IllegalArgumentException("Unexpected null order: " + this); + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java new file mode 100644 index 000000000000..5a5d5a3f2f8b --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortDirection.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Experimental; + +/** + * A sort direction used in sorting expressions. + */ +@Experimental +public enum SortDirection { + ASCENDING, DESCENDING; + + @Override + public String toString() { + switch (this) { + case ASCENDING: + return "ASC"; + case DESCENDING: + return "DESC"; + default: + throw new IllegalArgumentException("Unexpected sort direction: " + this); + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java new file mode 100644 index 000000000000..97ba5c150f35 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import org.apache.spark.annotation.Experimental; + +/** + * Represents a sort order in the public expression API. + */ +@Experimental +public interface SortOrder extends Expression { + /** + * Returns the sort expression. + */ + Expression expression(); + + /** + * Returns the sort direction. + */ + SortDirection direction(); + + /** + * Returns the null ordering. + */ + NullOrdering nullOrdering(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java new file mode 100644 index 000000000000..6362480934cc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RequiresDistributionAndOrdering.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.connector.distributions.Distribution; +import org.apache.spark.sql.connector.distributions.UnspecifiedDistribution; +import org.apache.spark.sql.connector.expressions.SortOrder; + +/** + * A write that requires a specific distribution and ordering of data. + */ +@Experimental +public interface RequiresDistributionAndOrdering extends Write { + /** + * Returns the distribution required by this write. + *
<p>
+ * Spark will distribute incoming records to satisfy the required distribution before + * passing those records to the data source table on write. + *
<p>
+ * Implementations may return {@link UnspecifiedDistribution} if they don't require any specific + * distribution of data on write. + * + * @return the required distribution + */ + Distribution requiredDistribution(); + + /** + * Returns the ordering required by this write. + *
<p>
+ * Spark will order incoming records within partitions to satisfy the required ordering + * before passing those records to the data source table on write. + *
<p>
+ * Implementations may return an empty array if they don't require any specific ordering of data + * on write. + * + * @return the required ordering + */ + SortOrder[] requiredOrdering(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java new file mode 100644 index 000000000000..7b1c4ffe319e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/Write.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write; + +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.write.streaming.StreamingWrite; + +/** + * A logical representation of a data source write. + *
<p>
+ * This logical representation is shared between batch and streaming write. Data sources must + * implement the corresponding methods in this interface to match what the table promises + * to support. For example, {@link #toBatch()} must be implemented if the {@link Table} that + * creates this {@link Write} returns {@link TableCapability#BATCH_WRITE} support in its + * {@link Table#capabilities()}. + */ +public interface Write { + + default String description() { + return this.getClass().toString(); + } + + /** + * Returns a {@link BatchWrite} to write data to batch source. By default this method throws + * exception, data sources must overwrite this method to provide an implementation, if the + * {@link Table} that creates this write returns {@link TableCapability#BATCH_WRITE} support in + * its {@link Table#capabilities()}. + */ + default BatchWrite toBatch() { + throw new UnsupportedOperationException(description() + ": Batch write is not supported"); + } + + /** + * Returns a {@link StreamingWrite} to write data to streaming source. By default this method + * throws exception, data sources must overwrite this method to provide an implementation, if the + * {@link Table} that creates this write returns {@link TableCapability#STREAMING_WRITE} support + * in its {@link Table#capabilities()}. + */ + default StreamingWrite toStreaming() { + throw new UnsupportedOperationException(description() + ": Streaming write is not supported"); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/WriteBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/WriteBuilder.java index 5398ca46e977..6b556bbe5e49 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/WriteBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/WriteBuilder.java @@ -23,10 +23,10 @@ import org.apache.spark.sql.connector.write.streaming.StreamingWrite; /** - * An interface for building the {@link BatchWrite}. Implementations can mix in some interfaces to + * An interface for building the {@link Write}. Implementations can mix in some interfaces to * support different ways to write data to data sources. * - * Unless modified by a mixin interface, the {@link BatchWrite} configured by this builder is to + * Unless modified by a mixin interface, the {@link Write} configured by this builder is to * append data without affecting existing data. * * @since 3.0.0 @@ -34,6 +34,23 @@ @Evolving public interface WriteBuilder { + /** + * Returns a logical {@link Write} shared between batch and streaming. + */ + default Write build() { + return new Write() { + @Override + public BatchWrite toBatch() { + return buildForBatch(); + } + + @Override + public StreamingWrite toStreaming() { + return buildForStreaming(); + } + }; + } + /** * Returns a {@link BatchWrite} to write data to batch source. 
By default this method throws * exception, data sources must overwrite this method to provide an implementation, if the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 67056470418f..d417cb19363e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange} import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.write.Write import org.apache.spark.sql.types.{DataType, MetadataBuilder, StringType, StructType} /** @@ -65,7 +66,8 @@ case class AppendData( table: NamedRelation, query: LogicalPlan, writeOptions: Map[String, String], - isByName: Boolean) extends V2WriteCommand { + isByName: Boolean, + write: Option[Write] = None) extends V2WriteCommand { override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery) override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable) } @@ -94,7 +96,8 @@ case class OverwriteByExpression( deleteExpr: Expression, query: LogicalPlan, writeOptions: Map[String, String], - isByName: Boolean) extends V2WriteCommand { + isByName: Boolean, + write: Option[Write] = None) extends V2WriteCommand { override lazy val resolved: Boolean = { table.resolved && query.resolved && outputResolved && deleteExpr.resolved } @@ -132,7 +135,8 @@ case class OverwritePartitionsDynamic( table: NamedRelation, query: LogicalPlan, writeOptions: Map[String, String], - isByName: Boolean) extends V2WriteCommand { + isByName: Boolean, + write: Option[Write] = None) extends V2WriteCommand { override def withNewQuery(newQuery: LogicalPlan): OverwritePartitionsDynamic = { copy(query = newQuery) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/distributions/distributions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/distributions/distributions.scala new file mode 100644 index 000000000000..599f82b4dc52 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/distributions/distributions.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector.distributions + +import org.apache.spark.sql.connector.expressions.{Expression, SortOrder} + +private[sql] object LogicalDistributions { + + def unspecified(): UnspecifiedDistribution = { + UnspecifiedDistributionImpl + } + + def clustered(clustering: Array[Expression]): ClusteredDistribution = { + ClusteredDistributionImpl(clustering) + } + + def ordered(ordering: Array[SortOrder]): OrderedDistribution = { + OrderedDistributionImpl(ordering) + } +} + +private[sql] object UnspecifiedDistributionImpl extends UnspecifiedDistribution { + override def toString: String = "UnspecifiedDistribution" +} + +private[sql] final case class ClusteredDistributionImpl( + clusteringExprs: Seq[Expression]) extends ClusteredDistribution { + + override def clustering: Array[Expression] = clusteringExprs.toArray + + override def toString: String = { + s"ClusteredDistribution(${clusteringExprs.map(_.describe).mkString(", ")})" + } +} + +private[sql] final case class OrderedDistributionImpl( + orderingExprs: Seq[SortOrder]) extends OrderedDistribution { + + override def ordering: Array[SortOrder] = orderingExprs.toArray + + override def toString: String = { + s"OrderedDistribution(${orderingExprs.map(_.describe).mkString(", ")})" + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 321ea14d376b..2863d94d198b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -54,6 +54,13 @@ private[sql] object LogicalExpressions { def days(reference: NamedReference): DaysTransform = DaysTransform(reference) def hours(reference: NamedReference): HoursTransform = HoursTransform(reference) + + def sort( + reference: Expression, + direction: SortDirection, + nullOrdering: NullOrdering): SortOrder = { + SortValue(reference, direction, nullOrdering) + } } /** @@ -110,6 +117,18 @@ private[sql] final case class BucketTransform( } private[sql] object BucketTransform { + def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match { + case transform: Transform => + transform match { + case BucketTransform(n, FieldReference(parts)) => + Some((n, FieldReference(parts))) + case _ => + None + } + case _ => + None + } + def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match { case NamedTransform("bucket", Seq( Lit(value: Int, IntegerType), @@ -170,6 +189,18 @@ private[sql] final case class IdentityTransform( } private[sql] object IdentityTransform { + def unapply(expr: Expression): Option[FieldReference] = expr match { + case transform: Transform => + transform match { + case IdentityTransform(ref) => + Some(ref) + case _ => + None + } + case _ => + None + } + def unapply(transform: Transform): Option[FieldReference] = transform match { case NamedTransform("identity", Seq(Ref(parts))) => Some(FieldReference(parts)) @@ -185,6 +216,18 @@ private[sql] final case class YearsTransform( } private[sql] object YearsTransform { + def unapply(expr: Expression): Option[FieldReference] = expr match { + case transform: Transform => + transform match { + case YearsTransform(ref) => + Some(ref) + case _ => + None + } + case _ => + None + } + def unapply(transform: Transform): Option[FieldReference] = transform match { case NamedTransform("years", Seq(Ref(parts))) => 
Some(FieldReference(parts)) @@ -200,6 +243,18 @@ private[sql] final case class MonthsTransform( } private[sql] object MonthsTransform { + def unapply(expr: Expression): Option[FieldReference] = expr match { + case transform: Transform => + transform match { + case MonthsTransform(ref) => + Some(ref) + case _ => + None + } + case _ => + None + } + def unapply(transform: Transform): Option[FieldReference] = transform match { case NamedTransform("months", Seq(Ref(parts))) => Some(FieldReference(parts)) @@ -215,6 +270,18 @@ private[sql] final case class DaysTransform( } private[sql] object DaysTransform { + def unapply(expr: Expression): Option[FieldReference] = expr match { + case transform: Transform => + transform match { + case DaysTransform(ref) => + Some(ref) + case _ => + None + } + case _ => + None + } + def unapply(transform: Transform): Option[FieldReference] = transform match { case NamedTransform("days", Seq(Ref(parts))) => Some(FieldReference(parts)) @@ -230,6 +297,18 @@ private[sql] final case class HoursTransform( } private[sql] object HoursTransform { + def unapply(expr: Expression): Option[FieldReference] = expr match { + case transform: Transform => + transform match { + case HoursTransform(ref) => + Some(ref) + case _ => + None + } + case _ => + None + } + def unapply(transform: Transform): Option[FieldReference] = transform match { case NamedTransform("hours", Seq(Ref(parts))) => Some(FieldReference(parts)) @@ -261,3 +340,20 @@ private[sql] object FieldReference { LogicalExpressions.parseReference(column) } } + +private[sql] final case class SortValue( + expression: Expression, + direction: SortDirection, + nullOrdering: NullOrdering) extends SortOrder { + + override def describe(): String = s"$expression $direction $nullOrdering" +} + +private[sql] object SortValue { + def unapply(expr: Expression): Option[(Expression, SortDirection, NullOrdering)] = expr match { + case sort: SortOrder => + Some((sort.expression, sort.direction, sort.nullOrdering)) + case _ => + None + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index c4c5835d9d1f..ead18b0aecf3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -257,7 +257,7 @@ class InMemoryTable( } } - private abstract class TestBatchWrite extends BatchWrite { + protected abstract class TestBatchWrite extends BatchWrite { override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { BufferedRowsWriterFactory } @@ -265,13 +265,13 @@ class InMemoryTable( override def abort(messages: Array[WriterCommitMessage]): Unit = {} } - private object Append extends TestBatchWrite { + protected object Append extends TestBatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { withData(messages.map(_.asInstanceOf[BufferedRows])) } } - private object DynamicOverwrite extends TestBatchWrite { + protected object DynamicOverwrite extends TestBatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val newData = messages.map(_.asInstanceOf[BufferedRows]) dataMap --= newData.flatMap(_.rows.map(getKey)) @@ -279,7 +279,7 @@ class InMemoryTable( } } - private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { + protected class Overwrite(filters: Array[Filter]) extends 
TestBatchWrite { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val deleteKeys = InMemoryTable.filtersToKeys( @@ -289,7 +289,7 @@ class InMemoryTable( } } - private object TruncateAndAppend extends TestBatchWrite { + protected object TruncateAndAppend extends TestBatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { dataMap.clear withData(messages.map(_.asInstanceOf[BufferedRows])) diff --git a/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1Write.java b/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1Write.java new file mode 100644 index 000000000000..b01b0aeb6c67 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1Write.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.write; + +import org.apache.spark.annotation.Unstable; +import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.sources.InsertableRelation; + +/** + * A logical write that should be executed using V1 InsertableRelation interface. + *
<p>
+ * Tables that have {@link TableCapability#V1_BATCH_WRITE} in the list of their capabilities + * must create {@link V1WriteBuilder} and build {@link V1Write}. + */ +@Unstable +public interface V1Write extends Write { + InsertableRelation toInsertableRelation(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java index 89b567b5231a..a762a49a2af0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/connector/write/V1WriteBuilder.java @@ -33,6 +33,11 @@ */ @Unstable public interface V1WriteBuilder extends WriteBuilder { + + default Write build() { + return (V1Write) this::buildForV1Write; + } + /** * Creates an InsertableRelation that allows appending a DataFrame to a * a destination (using data source-specific parameters). The insert method will only be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 33b86a2b5340..22458e049c69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions import org.apache.spark.sql.execution.datasources.SchemaPruning -import org.apache.spark.sql.execution.datasources.v2.V2ScanRelationPushDown +import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes} import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} @@ -39,6 +39,9 @@ class SparkOptimizer( // TODO: move SchemaPruning into catalyst SchemaPruning :: V2ScanRelationPushDown :: PruneFileSourcePartitions :: Nil + override def dataSourceRewriteRules: Seq[Rule[LogicalPlan]] = + V2Writes :: Nil + override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("PartitionPruning", Once, @@ -70,7 +73,8 @@ class SparkOptimizer( ExtractPythonUDFFromJoinCondition.ruleName :+ ExtractPythonUDFFromAggregate.ruleName :+ ExtractGroupingPythonUDFFromAggregate.ruleName :+ ExtractPythonUDFs.ruleName :+ - V2ScanRelationPushDown.ruleName + V2ScanRelationPushDown.ruleName :+ + V2Writes.ruleName /** * Optimization batches that are executed before the regular optimization batches (also before diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 5289d359f780..99076fa38534 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.catalog.{CatalogV2Util, StagingTableCatalog, 
SupportsNamespaces, SupportsPartitionManagement, TableCapability, TableCatalog, TableChange} import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} +import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} @@ -178,15 +179,20 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat orCreate = orCreate) :: Nil } - case AppendData(r: DataSourceV2Relation, query, writeOptions, _) => + case AppendData(r: DataSourceV2Relation, query, writeOptions, _, write) => r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - AppendDataExecV1(v1, writeOptions.asOptions, query, refreshCache(r)) :: Nil + AppendDataExecV1( + v1, writeOptions.asOptions, query, + refreshCache(r), write.map(_.asInstanceOf[V1Write])) :: Nil case v2 => - AppendDataExec(v2, writeOptions.asOptions, planLater(query), refreshCache(r)) :: Nil + AppendDataExec( + v2, writeOptions.asOptions, planLater(query), + refreshCache(r), write.map(_.toBatch)) :: Nil } - case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, writeOptions, _) => + case OverwriteByExpression( + r: DataSourceV2Relation, deleteExpr, query, writeOptions, _, write) => // fail if any filter cannot be converted. correctness depends on removing all matching data. val filters = splitConjunctivePredicates(deleteExpr).map { filter => DataSourceStrategy.translateFilter(deleteExpr, @@ -195,16 +201,19 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat }.toArray r.table.asWritable match { case v1 if v1.supports(TableCapability.V1_BATCH_WRITE) => - OverwriteByExpressionExecV1(v1, filters, writeOptions.asOptions, - query, refreshCache(r)) :: Nil + OverwriteByExpressionExecV1( + v1, filters, writeOptions.asOptions, query, + refreshCache(r), write.map(_.asInstanceOf[V1Write])) :: Nil case v2 => - OverwriteByExpressionExec(v2, filters, - writeOptions.asOptions, planLater(query), refreshCache(r)) :: Nil + OverwriteByExpressionExec( + v2, filters, writeOptions.asOptions, planLater(query), + refreshCache(r), write.map(_.toBatch)) :: Nil } - case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _) => + case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, writeOptions, _, write) => OverwritePartitionsDynamicExec( - r.table.asWritable, writeOptions.asOptions, planLater(query), refreshCache(r)) :: Nil + r.table.asWritable, writeOptions.asOptions, planLater(query), + refreshCache(r), write.map(_.toBatch)) :: Nil case DeleteFromTable(relation, condition) => relation match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala index cb4a2994de1f..f697aba46d0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableCapabilityCheck.scala @@ -49,14 +49,14 @@ object TableCapabilityCheck extends (LogicalPlan => Unit) { // TODO: check STREAMING_WRITE capability. 
It's not doable now because we don't have a // a logical plan for streaming write. - case AppendData(r: DataSourceV2Relation, _, _, _) if !supportsBatchWrite(r.table) => + case AppendData(r: DataSourceV2Relation, _, _, _, _) if !supportsBatchWrite(r.table) => failAnalysis(s"Table ${r.table.name()} does not support append in batch mode.") - case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _) + case OverwritePartitionsDynamic(r: DataSourceV2Relation, _, _, _, _) if !r.table.supports(BATCH_WRITE) || !r.table.supports(OVERWRITE_DYNAMIC) => failAnalysis(s"Table ${r.table.name()} does not support dynamic overwrite in batch mode.") - case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _) => + case OverwriteByExpression(r: DataSourceV2Relation, expr, _, _, _, _) => expr match { case Literal(true, BooleanType) => if (!supportsBatchWrite(r.table) || diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala index 9d2cea9fbaff..6a18e9cf88b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V1FallbackWriters.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.SupportsWrite -import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder} +import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1Write, V1WriteBuilder, WriteBuilder} import org.apache.spark.sql.execution.{AlreadyOptimized, SparkPlan} import org.apache.spark.sql.sources.{AlwaysTrue, Filter, InsertableRelation} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -38,10 +38,11 @@ case class AppendDataExecV1( table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, plan: LogicalPlan, - refreshCache: () => Unit) extends V1FallbackWriters { + refreshCache: () => Unit, + override val write: Option[V1Write] = None) extends V1FallbackWriters { - override protected def run(): Seq[InternalRow] = { - writeWithV1(newWriteBuilder().buildForV1Write(), refreshCache = refreshCache) + override protected def buildAndRun(): Seq[InternalRow] = { + writeWithV1(newWriteBuilder().buildForV1Write()) } } @@ -61,20 +62,20 @@ case class OverwriteByExpressionExecV1( deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, plan: LogicalPlan, - refreshCache: () => Unit) extends V1FallbackWriters { + refreshCache: () => Unit, + override val write: Option[V1Write] = None) extends V1FallbackWriters { private def isTruncate(filters: Array[Filter]): Boolean = { filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] } - override protected def run(): Seq[InternalRow] = { + override protected def buildAndRun(): Seq[InternalRow] = { newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => - writeWithV1(builder.truncate().asV1Builder.buildForV1Write(), refreshCache = refreshCache) + writeWithV1(builder.truncate().asV1Builder.buildForV1Write()) case builder: SupportsOverwrite => - writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write(), - refreshCache = refreshCache) + 
writeWithV1(builder.overwrite(deleteWhere).asV1Builder.buildForV1Write()) case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") @@ -89,6 +90,21 @@ sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write { def table: SupportsWrite def writeOptions: CaseInsensitiveStringMap + def refreshCache: () => Unit + def write: Option[V1Write] = None + + override def run(): Seq[InternalRow] = { + val writtenRows = write match { + case Some(v1Write) => + writeWithV1(v1Write.toInsertableRelation) + case _ => + buildAndRun() + } + refreshCache() + writtenRows + } + + protected def buildAndRun(): Seq[InternalRow] protected implicit class toV1WriteBuilder(builder: WriteBuilder) { def asV1Builder: V1WriteBuilder = builder match { @@ -115,14 +131,10 @@ sealed trait V1FallbackWriters extends V2CommandExec with SupportsV1Write { trait SupportsV1Write extends SparkPlan { def plan: LogicalPlan - protected def writeWithV1( - relation: InsertableRelation, - refreshCache: () => Unit = () => ()): Seq[InternalRow] = { + protected def writeWithV1(relation: InsertableRelation): Seq[InternalRow] = { val session = sqlContext.sparkSession // The `plan` is already optimized, we should not analyze and optimize it again. relation.insert(AlreadyOptimized.dataFrame(session, plan), overwrite = false) - refreshCache() - Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala new file mode 100644 index 000000000000..874f07b998ad --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.UUID + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.{NamedExpression, PredicateHelper, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, RepartitionByExpression, Sort} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.distributions.{ClusteredDistribution, OrderedDistribution, UnspecifiedDistribution} +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, IdentityTransform, NullOrdering, SortDirection, SortValue} +import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, RequiresDistributionAndOrdering, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, Write, WriteBuilder} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.{AlwaysTrue, Filter} + +/** + * A rule that constructs [[Write]]s. + * + * This rule does resolution in the optimizer because some nodes like [[OverwriteByExpression]] + * must undergo the expression optimization before we can construct a logical write. + */ +object V2Writes extends Rule[LogicalPlan] with PredicateHelper { + + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { + case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) => + val writeBuilder = newWriteBuilder(r.table, query, options) + val write = writeBuilder.build() + a.copy(write = Some(write), query = addDistributionAndOrdering(write, query)) + + case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) => + // fail if any filter cannot be converted. correctness depends on removing all matching data. 
+ val filters = splitConjunctivePredicates(deleteExpr).flatMap { p => + val filter = DataSourceStrategy.translateFilter(p, supportNestedPredicatePushdown = true) + if (filter.isEmpty) { + throw new AnalysisException(s"Cannot translate expression to source filter: $p") + } + filter + }.toArray + + val table = r.table + val writeBuilder = newWriteBuilder(table, query, options) + val write = writeBuilder match { + case builder: SupportsTruncate if isTruncate(filters) => + builder.truncate().build() + case builder: SupportsOverwrite => + builder.overwrite(filters).build() + case _ => + throw new SparkException(s"Table does not support overwrite by expression: $table") + } + + o.copy(write = Some(write), query = addDistributionAndOrdering(write, query)) + + case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) => + val table = r.table + val writeBuilder = newWriteBuilder(table, query, options) + val write = writeBuilder match { + case builder: SupportsDynamicOverwrite => + builder.overwriteDynamicPartitions().build() + case _ => + throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + } + o.copy(write = Some(write), query = addDistributionAndOrdering(write, query)) + } + + private def newWriteBuilder( + table: Table, + query: LogicalPlan, + writeOptions: Map[String, String]): WriteBuilder = { + + val info = LogicalWriteInfoImpl( + queryId = UUID.randomUUID().toString, + query.schema, + writeOptions.asOptions) + table.asWritable.newWriteBuilder(info) + } + + private def isTruncate(filters: Array[Filter]): Boolean = { + filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] + } + + private def addDistributionAndOrdering( + write: Write, + query: LogicalPlan): LogicalPlan = write match { + + case write: RequiresDistributionAndOrdering => + val sqlConf = SQLConf.get + val resolver = sqlConf.resolver + + val distribution = write.requiredDistribution match { + case d: OrderedDistribution => + d.ordering.map(e => toCatalyst(e, query, resolver)) + case d: ClusteredDistribution => + d.clustering.map(e => toCatalyst(e, query, resolver)) + case _: UnspecifiedDistribution => + Array.empty[catalyst.expressions.Expression] + } + + val queryWithDistribution = if (distribution.nonEmpty) { + val numShufflePartitions = sqlConf.numShufflePartitions + // the conversion to catalyst expressions above produces SortOrder expressions + // for OrderedDistribution and generic expressions for ClusteredDistribution + // this allows RepartitionByExpression to pick either range or hash partitioning + RepartitionByExpression(distribution, query, numShufflePartitions) + } else { + query + } + + val ordering = write.requiredOrdering.toSeq + .map(e => toCatalyst(e, query, resolver)) + .asInstanceOf[Seq[catalyst.expressions.SortOrder]] + + val queryWithDistributionAndOrdering = if (ordering.nonEmpty) { + Sort(ordering, global = false, queryWithDistribution) + } else { + queryWithDistribution + } + + queryWithDistributionAndOrdering + case _ => + query + } + + private def toCatalyst( + expr: Expression, + query: LogicalPlan, + resolver: Resolver): catalyst.expressions.Expression = { + def resolve(ref: FieldReference): NamedExpression = { + // this part is controversial as we perform resolution in the optimizer + // we cannot perform this step in the analyzer since we need to optimize expressions + // in nodes like OverwriteByExpression before constructing a logical write + query.resolve(ref.parts, resolver) match { + case Some(attr) => attr + case None => throw 
new AnalysisException(s"Cannot resolve '$ref' using ${query.output}") + } + } + expr match { + case SortValue(child, direction, nullOrdering) => + val catalystChild = toCatalyst(child, query, resolver) + SortOrder(catalystChild, toCatalyst(direction), toCatalyst(nullOrdering), Seq.empty) + case IdentityTransform(ref) => + resolve(ref) + case ref: FieldReference => + resolve(ref) + case _ => + throw new RuntimeException(s"$expr is not currently supported") + } + } + + private def toCatalyst(direction: SortDirection): catalyst.expressions.SortDirection = { + direction match { + case SortDirection.ASCENDING => catalyst.expressions.Ascending + case SortDirection.DESCENDING => catalyst.expressions.Descending + } + } + + private def toCatalyst(nullOrdering: NullOrdering): catalyst.expressions.NullOrdering = { + nullOrdering match { + case NullOrdering.NULLS_FIRST => catalyst.expressions.NullsFirst + case NullOrdering.NULLS_LAST => catalyst.expressions.NullsLast + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 47aad2bcb2c5..537b9162dbf1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connector.catalog.{Identifier, StagedTable, StagingTableCatalog, SupportsWrite, Table, TableCatalog} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, LogicalWriteInfoImpl, PhysicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1WriteBuilder, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, LogicalWriteInfoImpl, PhysicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.{AlwaysTrue, Filter} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -216,12 +216,12 @@ case class AppendDataExec( table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, query: SparkPlan, - refreshCache: () => Unit) extends V2TableWriteExec with BatchWriteHelper { + refreshCache: () => Unit, + override val write: Option[BatchWrite] = None) + extends V2ExistingTableWriteExec with BatchWriteHelper { - override protected def run(): Seq[InternalRow] = { - val writtenRows = writeWithV2(newWriteBuilder().buildForBatch()) - refreshCache() - writtenRows + override protected def buildAndRun(): Seq[InternalRow] = { + writeWithV2(newWriteBuilder().buildForBatch()) } } @@ -240,14 +240,16 @@ case class OverwriteByExpressionExec( deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, query: SparkPlan, - refreshCache: () => Unit) extends V2TableWriteExec with BatchWriteHelper { + refreshCache: () => Unit, + override val write: Option[BatchWrite] = None) + extends V2ExistingTableWriteExec with BatchWriteHelper { private def isTruncate(filters: Array[Filter]): Boolean = { filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] } - override protected def run(): 
Seq[InternalRow] = { - val writtenRows = newWriteBuilder() match { + override protected def buildAndRun(): Seq[InternalRow] = { + newWriteBuilder() match { case builder: SupportsTruncate if isTruncate(deleteWhere) => writeWithV2(builder.truncate().buildForBatch()) @@ -257,8 +259,6 @@ case class OverwriteByExpressionExec( case _ => throw new SparkException(s"Table does not support overwrite by expression: $table") } - refreshCache() - writtenRows } } @@ -276,18 +276,18 @@ case class OverwritePartitionsDynamicExec( table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, query: SparkPlan, - refreshCache: () => Unit) extends V2TableWriteExec with BatchWriteHelper { + refreshCache: () => Unit, + override val write: Option[BatchWrite] = None) + extends V2ExistingTableWriteExec with BatchWriteHelper { - override protected def run(): Seq[InternalRow] = { - val writtenRows = newWriteBuilder() match { + override protected def buildAndRun(): Seq[InternalRow] = { + newWriteBuilder() match { case builder: SupportsDynamicOverwrite => writeWithV2(builder.overwriteDynamicPartitions().buildForBatch()) case _ => throw new SparkException(s"Table does not support dynamic partition overwrite: $table") } - refreshCache() - writtenRows } } @@ -319,6 +319,24 @@ trait BatchWriteHelper { } } +trait V2ExistingTableWriteExec extends V2TableWriteExec { + def refreshCache: () => Unit + def write: Option[BatchWrite] = None + + override protected def run(): Seq[InternalRow] = { + val writtenRows = write match { + case Some(batchWrite) => + writeWithV2(batchWrite) + case _ => + buildAndRun() + } + refreshCache() + writtenRows + } + + protected def buildAndRun(): Seq[InternalRow] +} + /** * The base physical plan for writing data into data source v2. */ @@ -477,9 +495,10 @@ private[v2] trait TableWriteExecHelper extends V2TableWriteExec with SupportsV1W writeOptions) val writeBuilder = table.newWriteBuilder(info) - val writtenRows = writeBuilder match { - case v1: V1WriteBuilder => writeWithV1(v1.buildForV1Write()) - case v2 => writeWithV2(v2.buildForBatch()) + val write = writeBuilder.build() + val writtenRows = write match { + case v1: V1Write => writeWithV1(v1.toInsertableRelation) + case v2 => writeWithV2(v2.toBatch) } table match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala new file mode 100644 index 000000000000..597598b2d2a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -0,0 +1,594 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.connector + +import java.util +import java.util.Collections + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{catalyst, DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.analysis.{TableAlreadyExistsException, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning, UnknownPartitioning} +import org.apache.spark.sql.connector.catalog.{Identifier, Table} +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder, Transform} +import org.apache.spark.sql.connector.expressions.LogicalExpressions._ +import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, RequiresDistributionAndOrdering, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, Write, WriteBuilder} +import org.apache.spark.sql.execution.{QueryExecution, SortExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.v2.V2TableWriteExec +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.util.{CaseInsensitiveStringMap, QueryExecutionListener} + +class WriteDistributionAndOrderingSuite + extends QueryTest with SharedSparkSession with BeforeAndAfter { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + before { + spark.conf.set("spark.sql.catalog.testcat", classOf[ExtendedInMemoryTableCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.clear() + } + + private val writeOperations = Seq("append", "overwrite", "overwriteDynamic") + + private val namespace = Array("ns1") + private val ident = Identifier.of(namespace, "test_table") + private val tableNameAsString = "testcat." 
+ ident.toString + private val emptyProps = Collections.emptyMap[String, String] + private val schema = new StructType() + .add("id", IntegerType) + .add("data", StringType) + + writeOperations.foreach { operation => + test(s"ordered distribution and sort with same exprs ($operation)") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val distribution = Distributions.ordered(ordering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + UnresolvedAttribute("data"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val numShufflePartitions = SQLConf.get.numShufflePartitions + val writePartitioning = RangePartitioning(writeOrdering, numShufflePartitions) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeOperation = operation + ) + } + } + + writeOperations.foreach { operation => + test(s"clustered distribution and sort with same exprs ($operation)") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val clustering = Array[Expression](FieldReference("data"), FieldReference("id")) + val distribution = Distributions.clustered(clustering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + attr("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioningExprs = Seq(attr("data"), attr("id")) + val numShufflePartitions = SQLConf.get.numShufflePartitions + val writePartitioning = HashPartitioning(writePartitioningExprs, numShufflePartitions) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeOperation = operation + ) + } + } + + writeOperations.foreach { operation => + test(s"clustered distribution and sort with extended exprs ($operation)") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val clustering = Array[Expression](FieldReference("data")) + val distribution = Distributions.clustered(clustering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + attr("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioningExprs = Seq(attr("data")) + val numShufflePartitions = SQLConf.get.numShufflePartitions + val writePartitioning = HashPartitioning(writePartitioningExprs, numShufflePartitions) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeOperation = operation + ) + } + } + + writeOperations.foreach { operation => + test(s"unspecified distribution and local sort ($operation)") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), 
SortDirection.DESCENDING, NullOrdering.NULLS_FIRST) + ) + val distribution = Distributions.unspecified() + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + attr("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioning = UnknownPartitioning(0) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeOperation = operation + ) + } + } + + writeOperations.foreach { operation => + test(s"unspecified distribution and no sort ($operation)") { + val ordering = Array.empty[SortOrder] + val distribution = Distributions.unspecified() + + val writeOrdering = Seq.empty[catalyst.expressions.SortOrder] + val writePartitioning = UnknownPartitioning(0) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeOperation = operation + ) + } + } + + writeOperations.foreach { operation => + test(s"ordered distribution and sort with manual global sort ($operation)") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val distribution = Distributions.ordered(ordering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + UnresolvedAttribute("data"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + UnresolvedAttribute("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val numShufflePartitions = SQLConf.get.numShufflePartitions + val writePartitioning = RangePartitioning(writeOrdering, numShufflePartitions) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.orderBy("data", "id"), + writeOperation = operation + ) + } + } + + writeOperations.foreach { operation => + test(s"ordered distribution and sort with incompatible global sort ($operation)") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val distribution = Distributions.ordered(ordering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + UnresolvedAttribute("data"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + UnresolvedAttribute("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val numShufflePartitions = SQLConf.get.numShufflePartitions + val writePartitioning = RangePartitioning(writeOrdering, numShufflePartitions) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.orderBy(df("data").desc, df("id").asc), + writeOperation = operation + ) + } + } + + writeOperations.foreach { operation => + test(s"ordered distribution and sort with manual local sort ($operation)") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), 
SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val distribution = Distributions.ordered(ordering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + UnresolvedAttribute("data"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + UnresolvedAttribute("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val numShufflePartitions = SQLConf.get.numShufflePartitions + val writePartitioning = RangePartitioning(writeOrdering, numShufflePartitions) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.sortWithinPartitions("data", "id"), + writeOperation = operation + ) + } + } + + // TODO: do we need to dedup repartitions too? RepartitionByExpr -> Projects -> RepartitionByExpr + writeOperations.foreach { operation => + ignore(s"ordered distribution and sort with manual repartition ($operation)") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val distribution = Distributions.ordered(ordering) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + UnresolvedAttribute("data"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + UnresolvedAttribute("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val numShufflePartitions = SQLConf.get.numShufflePartitions + val writePartitioning = RangePartitioning(writeOrdering, numShufflePartitions) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.repartitionByRange(df("data"), df("id")), + writeOperation = operation + ) + } + } + + writeOperations.foreach { operation => + test(s"clustered distribution and local sort with manual global sort ($operation)") { + val ordering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val distribution = Distributions.clustered(Array(FieldReference("data"))) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + UnresolvedAttribute("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + UnresolvedAttribute("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioningExprs = Seq(attr("data")) + val numShufflePartitions = SQLConf.get.numShufflePartitions + val writePartitioning = HashPartitioning(writePartitioningExprs, numShufflePartitions) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.orderBy("data", "id"), + writeOperation = operation + ) + } + } + + writeOperations.foreach { operation => + test(s"clustered distribution and local sort with manual local sort ($operation)") { + val 
ordering = Array[SortOrder]( + sort(FieldReference("data"), SortDirection.DESCENDING, NullOrdering.NULLS_FIRST), + sort(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST) + ) + val distribution = Distributions.clustered(Array(FieldReference("data"))) + + val writeOrdering = Seq( + catalyst.expressions.SortOrder( + UnresolvedAttribute("data"), + catalyst.expressions.Descending, + catalyst.expressions.NullsFirst, + Seq.empty + ), + catalyst.expressions.SortOrder( + UnresolvedAttribute("id"), + catalyst.expressions.Ascending, + catalyst.expressions.NullsFirst, + Seq.empty + ) + ) + val writePartitioningExprs = Seq(attr("data")) + val numShufflePartitions = SQLConf.get.numShufflePartitions + val writePartitioning = HashPartitioning(writePartitioningExprs, numShufflePartitions) + + checkWriteRequirements( + tableDistribution = distribution, + tableOrdering = ordering, + expectedWritePartitioning = writePartitioning, + expectedWriteOrdering = writeOrdering, + writeTransform = df => df.sortWithinPartitions("data", "id"), + writeOperation = operation + ) + } + } + + private def checkWriteRequirements( + tableDistribution: Distribution, + tableOrdering: Array[SortOrder], + expectedWritePartitioning: physical.Partitioning, + expectedWriteOrdering: Seq[catalyst.expressions.SortOrder], + writeTransform: DataFrame => DataFrame = df => df, + writeOperation: String = "append"): Unit = { + + catalog.createTable(ident, schema, Array.empty, emptyProps, tableDistribution, tableOrdering) + + val df = spark.createDataFrame(Seq((1, "a"), (2, "b"), (3, "c"))).toDF("id", "data") + val writer = writeTransform(df).writeTo(tableNameAsString) + val executedPlan = writeOperation match { + case "append" => execute(writer.append()) + case "overwrite" => execute(writer.overwrite(lit(true))) + case "overwriteDynamic" => execute(writer.overwritePartitions()) + } + + checkPartitioningAndOrdering(executedPlan, expectedWritePartitioning, expectedWriteOrdering) + + checkAnswer(spark.table(tableNameAsString), df) + } + + private def checkPartitioningAndOrdering( + plan: SparkPlan, + partitioning: physical.Partitioning, + ordering: Seq[catalyst.expressions.SortOrder]): Unit = { + + val sorts = plan.collect { case s: SortExec => s } + assert(sorts.size <= 1, "must be at most one sort") + val shuffles = plan.collect { case s: ShuffleExchangeExec => s } + assert(shuffles.size <= 1, "must be at most one shuffle") + + val actualPartitioning = plan.outputPartitioning + val expectedPartitioning = partitioning match { + case p: physical.RangePartitioning => + val resolvedOrdering = p.ordering.map(resolveAttrs(_, plan)) + p.copy(ordering = resolvedOrdering.asInstanceOf[Seq[catalyst.expressions.SortOrder]]) + case p: physical.HashPartitioning => + val resolvedExprs = p.expressions.map(resolveAttrs(_, plan)) + p.copy(expressions = resolvedExprs) + case other => other + } + // TODO: can be compatible, does not have to match 100% + assert(actualPartitioning == expectedPartitioning, "partitioning must match") + + val actualOrdering = plan.outputOrdering + val expectedOrdering = ordering.map(resolveAttrs(_, plan)) + // TODO: can be compatible, does not have to match 100% + assert(actualOrdering == expectedOrdering, "ordering must match") + } + + private def resolveAttrs( + expr: catalyst.expressions.Expression, + plan: SparkPlan): catalyst.expressions.Expression = { + + expr.transform { + case UnresolvedAttribute(parts) => + val attrName = parts.mkString(".") + plan.output.find(a => a.name == attrName).get + } + } + + private def
attr(name: String): UnresolvedAttribute = { + UnresolvedAttribute(name) + } + + private def catalog: ExtendedInMemoryTableCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("testcat") + catalog.asTableCatalog.asInstanceOf[ExtendedInMemoryTableCatalog] + } + + // executes a write operation and keeps the executed physical plan + private def execute(writeFunc: => Unit): SparkPlan = { + var executedPlan: SparkPlan = null + + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + executedPlan = qe.executedPlan + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + } + } + spark.listenerManager.register(listener) + + writeFunc + + sparkContext.listenerBus.waitUntilEmpty() + + executedPlan.asInstanceOf[V2TableWriteExec].query + } +} + +class ExtendedInMemoryTableCatalog extends InMemoryTableCatalog { + + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String], + distribution: Distribution, + ordering: Array[SortOrder]): Table = { + + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) + + val table = new ExtendedInMemoryTable( + s"$name.${ident.quoted}", schema, partitions, properties, distribution, ordering) + tables.put(ident, table) + namespaces.putIfAbsent(ident.namespace.toList, Map()) + table + } +} + +class ExtendedInMemoryTable( + override val name: String, + override val schema: StructType, + override val partitioning: Array[Transform], + override val properties: util.Map[String, String], + distribution: Distribution, + ordering: Array[SortOrder]) + extends InMemoryTable(name, schema, partitioning, properties) { + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + InMemoryTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties)) + InMemoryTable.maybeSimulateFailedTableWrite(info.options) + + new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite { + private var writer: BatchWrite = Append + + override def truncate(): WriteBuilder = { + assert(writer == Append) + writer = TruncateAndAppend + this + } + + override def overwrite(filters: Array[Filter]): WriteBuilder = { + assert(writer == Append) + writer = new Overwrite(filters) + this + } + + override def overwriteDynamicPartitions(): WriteBuilder = { + assert(writer == Append) + writer = DynamicOverwrite + this + } + + override def build(): Write = new RequiresDistributionAndOrdering { + override def requiredDistribution(): Distribution = distribution + override def requiredOrdering(): Array[SortOrder] = ordering + override def toBatch: BatchWrite = writer + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 38719311f1ae..f06fd801c1c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -1176,7 +1176,7 @@ class PlanResolutionSuite extends AnalysisTest { case Project(_, AsDataSourceV2Relation(r)) => assert(r.catalog.exists(_ == catlogIdent)) 
assert(r.identifier.exists(_.name() == tableIdent)) - case AppendData(r: DataSourceV2Relation, _, _, _) => + case AppendData(r: DataSourceV2Relation, _, _, _, _) => assert(r.catalog.exists(_ == catlogIdent)) assert(r.identifier.exists(_.name() == tableIdent)) case DescribeRelation(r: ResolvedTable, _, _) =>
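A minimal connector-side sketch of the API this patch adds, mirroring what ExtendedInMemoryTable does in the test suite above: a WriteBuilder whose build() returns a Write implementing RequiresDistributionAndOrdering, so Spark inserts the required shuffle and sort before invoking the BatchWrite. Only the Spark interfaces (WriteBuilder, Write, RequiresDistributionAndOrdering, Distributions, LogicalExpressions) come from this patch; ExampleWriteBuilder and exampleBatchWrite are hypothetical placeholders for a connector's own classes.

import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LogicalExpressions, NullOrdering, SortDirection, SortOrder}
import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering, Write, WriteBuilder}

class ExampleWriteBuilder(exampleBatchWrite: BatchWrite) extends WriteBuilder {
  override def build(): Write = new RequiresDistributionAndOrdering {
    // co-locate rows that share the same value of "data" in one partition
    override def requiredDistribution(): Distribution =
      Distributions.clustered(Array[Expression](FieldReference("data")))

    // and sort every partition by "id" ascending, nulls first
    override def requiredOrdering(): Array[SortOrder] = Array(
      LogicalExpressions.sort(
        FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))

    override def toBatch: BatchWrite = exampleBatchWrite
  }
}

With a clustered distribution on "data" and a sort on "id", the suite above expects exactly the plan Spark would produce for such a writer: a HashPartitioning shuffle on "data" plus a per-partition sort on "id", so the connector does not have to pre-shuffle or pre-sort the data itself.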