From 0177dea25401424997653ade1f3d99bf64f27aa3 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Thu, 28 Jul 2022 10:25:45 -0700 Subject: [PATCH 1/8] [SPARK-39914][SQL] Add DS V2 Filter to V2 Filter conversion --- .../internal/connector/PredicateUtils.scala | 73 ++++++++++++++++++- .../datasources/v2/V2PredicateSuite.scala | 64 ++++++++++++++++ 2 files changed, 134 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index ace6b30d4ccec..7a42681184938 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.internal.connector import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference} -import org.apache.spark.sql.connector.expressions.filter.Predicate -import org.apache.spark.sql.sources.{Filter, In} +import org.apache.spark.sql.connector.expressions.filter.{Predicate, And => V2And, Not => V2Not, Or => V2Or} +import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith} +import org.apache.spark.sql.types.StringType private[sql] object PredicateUtils { def toV1(predicate: Predicate): Option[Filter] = { predicate.name() match { - // TODO: add conversion for other V2 Predicate case "IN" if predicate.children()(0).isInstanceOf[NamedReference] => val attribute = predicate.children()(0).toString val values = predicate.children().drop(1) @@ -43,6 +43,73 @@ private[sql] object PredicateUtils { Some(In(attribute, Array.empty[Any])) } + case "=" | "<=>" | ">" | "<" | ">=" | "<=" if predicate.children().length == 2 && + predicate.children()(0).isInstanceOf[NamedReference] && + predicate.children()(1).isInstanceOf[LiteralValue[_]] => + val attribute = predicate.children()(0).toString + val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] + predicate.name() match { + case "=" => + Some(EqualTo(attribute, + CatalystTypeConverters.convertToScala(value.value, value.dataType))) + case "<=>" => + Some(EqualNullSafe(attribute, + CatalystTypeConverters.convertToScala(value.value, value.dataType))) + case ">" => + Some(GreaterThan(attribute, + CatalystTypeConverters.convertToScala(value.value, value.dataType))) + case ">=" => + Some(GreaterThanOrEqual(attribute, + CatalystTypeConverters.convertToScala(value.value, value.dataType))) + case "<" => + Some(LessThan(attribute, + CatalystTypeConverters.convertToScala(value.value, value.dataType))) + case "<=" => + Some(LessThanOrEqual(attribute, + CatalystTypeConverters.convertToScala(value.value, value.dataType))) + } + + case "IS_NULL" | "IS_NOT_NULL" if predicate.children().length == 1 && + predicate.children()(0).isInstanceOf[NamedReference] => + val attribute = predicate.children()(0).toString + predicate.name() match { + case "IS_NULL" => Some(IsNull(attribute)) + case "IS_NOT_NULL" => Some(IsNotNull(attribute)) + } + + case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if predicate.children().length == 2 && + predicate.children()(0).isInstanceOf[NamedReference] && + predicate.children()(1).isInstanceOf[LiteralValue[_]] => + val attribute = 
predicate.children()(0).toString + val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] + if (!value.dataType.sameType(StringType)) return None + predicate.name() match { + case "STARTS_WITH" => + Some(StringStartsWith(attribute, value.value.toString)) + case "ENDS_WITH" => + Some(StringEndsWith(attribute, value.value.toString)) + case "CONTAINS" => + Some(StringContains(attribute, value.value.toString)) + } + + case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty => + predicate.name() match { + case "ALWAYS_TRUE" => Some(AlwaysTrue()) + case "ALWAYS_FALSE" => Some(AlwaysFalse()) + } + + case "AND" => + val and = predicate.asInstanceOf[V2And] + Some(And(toV1(and.left()).get, toV1(and.right()).get)) + + case "OR" => + val or = predicate.asInstanceOf[V2Or] + Some(Or(toV1(or.left()).get, toV1(or.right()).get)) + + case "NOT" => + val not = predicate.asInstanceOf[V2Not] + Some(Not(toV1(not.child()).get)) + case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala index 2df8b8e56c44b..aeae33a30dc22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, LiteralValue} import org.apache.spark.sql.connector.expressions.filter._ import org.apache.spark.sql.execution.datasources.v2.V2PredicateSuite.ref +import org.apache.spark.sql.internal.connector.PredicateUtils import org.apache.spark.sql.sources.{AlwaysFalse => V1AlwaysFalse, AlwaysTrue => V1AlwaysTrue, And => V1And, EqualNullSafe, EqualTo, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not => V1Not, Or => V1Or, StringContains, StringEndsWith, StringStartsWith} import org.apache.spark.sql.types.{IntegerType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -34,6 +35,9 @@ class V2PredicateSuite extends SparkFunSuite { assert(predicate1.describe.equals("a.B = 1")) val v1Filter1 = EqualTo(ref("a", "B").describe(), 1) assert(v1Filter1.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter1) + assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) val predicate2 = new Predicate("=", Array[Expression](ref("a", "b.c"), LiteralValue(1, IntegerType))) @@ -41,6 +45,9 @@ class V2PredicateSuite extends SparkFunSuite { assert(predicate2.describe.equals("a.`b.c` = 1")) val v1Filter2 = EqualTo(ref("a", "b.c").describe(), 1) assert(v1Filter2.toV2 == predicate2) + assert(PredicateUtils.toV1(predicate2).get == v1Filter2) + assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2) + assert(PredicateUtils.toV1(predicate2).get.toV2 == predicate2) val predicate3 = new Predicate("=", Array[Expression](ref("`a`.b", "c"), LiteralValue(1, IntegerType))) @@ -48,6 +55,9 @@ class V2PredicateSuite extends SparkFunSuite { assert(predicate3.describe.equals("```a``.b`.c = 1")) val v1Filter3 = EqualTo(ref("`a`.b", "c").describe(), 1) assert(v1Filter3.toV2 == predicate3) + assert(PredicateUtils.toV1(predicate3).get == v1Filter3) + assert(PredicateUtils.toV1(v1Filter3.toV2).get == v1Filter3) + assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3) } 
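  // A minimal sketch (not taken from this patch) of the round trip that the new
  // assertions above exercise: PredicateUtils.toV1 maps a V2 Predicate to an
  // Option of a V1 Filter, and Filter.toV2 maps the V1 filter back to a V2 Predicate.
  // Assuming the imports and the ref(...) helper already used in this suite:
  //
  //   val v2Pred = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))
  //   val v1Filter = PredicateUtils.toV1(v2Pred)   // Some(EqualTo("a", 1))
  //   assert(v1Filter.get.toV2 == v2Pred)          // converting back restores the predicate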
test("AlwaysTrue") { @@ -59,6 +69,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1AlwaysTrue assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("AlwaysFalse") { @@ -70,6 +83,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1AlwaysFalse assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("EqualTo") { @@ -81,6 +97,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = EqualTo("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("EqualNullSafe") { @@ -92,6 +111,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = EqualNullSafe("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("LessThan") { @@ -103,6 +125,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = LessThan("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("LessThanOrEqual") { @@ -114,6 +139,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = LessThanOrEqual("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("GreatThan") { @@ -125,6 +153,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = GreaterThan("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("GreatThanOrEqual") { @@ -136,6 +167,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = GreaterThanOrEqual("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("In") { @@ -161,9 +195,15 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter1 = In("a", Array(1, 2, 3, 4)) assert(v1Filter1.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter1) + assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) val v1Filter2 = In("a", values.map(_.value())) assert(v1Filter2.toV2 == predicate3) + assert(PredicateUtils.toV1(predicate3).get == v1Filter2) + assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2) + assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3) } test("IsNull") { @@ -175,6 +215,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = IsNull("a") assert(v1Filter.toV2 == predicate1) + 
assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("IsNotNull") { @@ -186,6 +229,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = IsNotNull("a") assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("Not") { @@ -199,6 +245,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1Not(LessThan("a", 1)) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("And") { @@ -214,6 +263,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1And(EqualTo("a", 1), EqualTo("b", 1)) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("Or") { @@ -229,6 +281,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1Or(EqualTo("a", 1), EqualTo("b", 1)) assert(v1Filter.toV2.equals(predicate1)) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("StringStartsWith") { @@ -243,6 +298,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = StringStartsWith("a", "str") assert(v1Filter.toV2.equals(predicate1)) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("StringEndsWith") { @@ -257,6 +315,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = StringEndsWith("a", "str") assert(v1Filter.toV2.equals(predicate1)) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("StringContains") { @@ -271,6 +332,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = StringContains("a", "str") assert(v1Filter.toV2.equals(predicate1)) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } } From 0eee4a0ace08d51e345fd0d2ee8ec6cc705fb3fa Mon Sep 17 00:00:00 2001 From: huaxingao Date: Thu, 28 Jul 2022 11:11:03 -0700 Subject: [PATCH 2/8] fix style --- .../apache/spark/sql/internal/connector/PredicateUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index 7a42681184938..24ea73ec9b65c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal.connector import org.apache.spark.sql.catalyst.CatalystTypeConverters import 
org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference} -import org.apache.spark.sql.connector.expressions.filter.{Predicate, And => V2And, Not => V2Not, Or => V2Or} +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith} import org.apache.spark.sql.types.StringType From a167e611f418b1a6956c7b88a1a26677e6237251 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Fri, 29 Jul 2022 09:58:50 -0700 Subject: [PATCH 3/8] extract common code to a method --- .../connector/catalog/SupportsDeleteV2.java | 82 +++ .../internal/connector/PredicateUtils.scala | 21 +- .../connector/catalog/InMemoryBaseTable.scala | 677 ++++++++++++++++++ 3 files changed, 774 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDeleteV2.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDeleteV2.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDeleteV2.java new file mode 100644 index 0000000000000..64961045c8894 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDeleteV2.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.filter.AlwaysTrue; +import org.apache.spark.sql.connector.expressions.filter.Predicate; + +/** + * A mix-in interface for {@link Table} delete support. Data sources can implement this + * interface to provide the ability to delete data from tables that matches filter expressions. + * + * @since 3.4.0 + */ +@Evolving +public interface SupportsDeleteV2 extends TruncatableTable { + + /** + * Checks whether it is possible to delete data from a data source table that matches filter + * expressions. + *
<p>
+ * Rows should be deleted from the data source iff all of the filter expressions match. + * That is, the expressions must be interpreted as a set of filters that are ANDed together. + *
<p>
+ * Spark will call this method at planning time to check whether {@link #deleteWhere(Predicate[])} + * would reject the delete operation because it requires significant effort. If this method + * returns false, Spark will not call {@link #deleteWhere(Predicate[])} and will try to rewrite + * the delete operation and produce row-level changes if the data source table supports deleting + * individual records. + * + * @param predicates V2 filter expressions, used to select rows to delete when all expressions match + * @return true if the delete operation can be performed + * + * @since 3.4.0 + */ + default boolean canDeleteWhere(Predicate[] predicates) { + return true; + } + + /** + * Delete data from a data source table that matches filter expressions. Note that this method + * will be invoked only if {@link #canDeleteWhere(Predicate[])} returns true. + *
<p>
+ * Rows are deleted from the data source iff all of the filter expressions match. That is, the + * expressions must be interpreted as a set of filters that are ANDed together. + *
<p>
+ * Implementations may reject a delete operation if the delete isn't possible without significant + * effort. For example, partitioned data sources may reject deletes that do not filter by + * partition columns because the filter may require rewriting files without deleted records. + * To reject a delete implementations should throw {@link IllegalArgumentException} with a clear + * error message that identifies which expression was rejected. + * + * @param predicates predicate expressions, used to select rows to delete when all expressions match + * @throws IllegalArgumentException If the delete is rejected due to required effort + */ + void deleteWhere(Predicate[] predicates); + + @Override + default boolean truncateTable() { + Predicate[] filters = new Predicate[] { new AlwaysTrue() }; + boolean canDelete = canDeleteWhere(filters); + if (canDelete) { + deleteWhere(filters); + } + return canDelete; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index 24ea73ec9b65c..309d938b1b96c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -26,6 +26,17 @@ import org.apache.spark.sql.types.StringType private[sql] object PredicateUtils { def toV1(predicate: Predicate): Option[Filter] = { + + def isValidBinaryPredicate(): Boolean = { + if (predicate.children().length == 2 && + predicate.children()(0).isInstanceOf[NamedReference] && + predicate.children()(1).isInstanceOf[LiteralValue[_]]) { + true + } else { + false + } + } + predicate.name() match { case "IN" if predicate.children()(0).isInstanceOf[NamedReference] => val attribute = predicate.children()(0).toString @@ -43,9 +54,7 @@ private[sql] object PredicateUtils { Some(In(attribute, Array.empty[Any])) } - case "=" | "<=>" | ">" | "<" | ">=" | "<=" if predicate.children().length == 2 && - predicate.children()(0).isInstanceOf[NamedReference] && - predicate.children()(1).isInstanceOf[LiteralValue[_]] => + case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate => val attribute = predicate.children()(0).toString val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] predicate.name() match { @@ -77,9 +86,7 @@ private[sql] object PredicateUtils { case "IS_NOT_NULL" => Some(IsNotNull(attribute)) } - case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if predicate.children().length == 2 && - predicate.children()(0).isInstanceOf[NamedReference] && - predicate.children()(1).isInstanceOf[LiteralValue[_]] => + case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate => val attribute = predicate.children()(0).toString val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] if (!value.dataType.sameType(StringType)) return None @@ -113,4 +120,6 @@ private[sql] object PredicateUtils { case _ => None } } + + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala new file mode 100644 index 0000000000000..c76a9969077cb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -0,0 +1,677 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import java.time.{Instant, ZoneId} +import java.time.temporal.ChronoUnit +import java.util +import java.util.OptionalLong + +import scala.collection.mutable + +import org.scalatest.Assertions._ + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow} +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils} +import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} +import org.apache.spark.sql.connector.expressions._ +import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} +import org.apache.spark.sql.connector.read._ +import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} +import org.apache.spark.sql.connector.write._ +import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} +import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.unsafe.types.UTF8String + +/** + * A simple in-memory table. Rows are stored as a buffered group produced by each output task. 
+ */ +class InMemoryBaseTable( + val name: String, + val schema: StructType, + override val partitioning: Array[Transform], + override val properties: util.Map[String, String], + val distribution: Distribution = Distributions.unspecified(), + val ordering: Array[SortOrder] = Array.empty, + val numPartitions: Option[Int] = None, + val isDistributionStrictlyRequired: Boolean = true) + extends Table with SupportsRead with SupportsWrite + with SupportsMetadataColumns { + + protected object PartitionKeyColumn extends MetadataColumn { + override def name: String = "_partition" + override def dataType: DataType = StringType + override def comment: String = "Partition key used to store the row" + } + + private object IndexColumn extends MetadataColumn { + override def name: String = "index" + override def dataType: DataType = IntegerType + override def comment: String = "Metadata column used to conflict with a data column" + } + + // purposely exposes a metadata column that conflicts with a data column in some tests + override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn) + private val metadataColumnNames = metadataColumns.map(_.name).toSet -- schema.map(_.name) + + private val allowUnsupportedTransforms = + properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean + + partitioning.foreach { + case _: IdentityTransform => + case _: YearsTransform => + case _: MonthsTransform => + case _: DaysTransform => + case _: HoursTransform => + case _: BucketTransform => + case _: SortedBucketTransform => + case t if !allowUnsupportedTransforms => + throw new IllegalArgumentException(s"Transform $t is not a supported transform") + } + + // The key `Seq[Any]` is the partition values. + val dataMap: mutable.Map[Seq[Any], BufferedRows] = mutable.Map.empty + + def data: Array[BufferedRows] = dataMap.values.toArray + + def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq + + private val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref => + schema.findNestedField(ref.fieldNames(), includeCollections = false) match { + case Some(_) => ref.fieldNames() + case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.") + } + } + + private val UTC = ZoneId.of("UTC") + private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate + + protected def getKey(row: InternalRow): Seq[Any] = { + getKey(row, schema) + } + + protected def getKey(row: InternalRow, rowSchema: StructType): Seq[Any] = { + @scala.annotation.tailrec + def extractor( + fieldNames: Array[String], + schema: StructType, + row: InternalRow): (Any, DataType) = { + val index = schema.fieldIndex(fieldNames(0)) + val value = row.toSeq(schema).apply(index) + if (fieldNames.length > 1) { + (value, schema(index).dataType) match { + case (row: InternalRow, nestedSchema: StructType) => + extractor(fieldNames.drop(1), nestedSchema, row) + case (_, dataType) => + throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}") + } + } else { + (value, schema(index).dataType) + } + } + + val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(rowSchema) + partitioning.map { + case IdentityTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row)._1 + case YearsTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row) match { + case (days: Int, DateType) => + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = 
DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) + case (v, t) => + throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") + } + case MonthsTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row) match { + case (days: Int, DateType) => + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) + case (micros: Long, TimestampType) => + val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate + ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate) + case (v, t) => + throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") + } + case DaysTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row) match { + case (days, DateType) => + days + case (micros: Long, TimestampType) => + ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) + case (v, t) => + throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") + } + case HoursTransform(ref) => + extractor(ref.fieldNames, cleanedSchema, row) match { + case (micros: Long, TimestampType) => + ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) + case (v, t) => + throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") + } + case BucketTransform(numBuckets, cols, _) => + val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row)) + var valueHashCode = 0 + valueTypePairs.foreach( pair => + if ( pair._1 != null) valueHashCode += pair._1.hashCode() + ) + var dataTypeHashCode = 0 + valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode()) + ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets + } + } + + protected def addPartitionKey(key: Seq[Any]): Unit = {} + + protected def renamePartitionKey( + partitionSchema: StructType, + from: Seq[Any], + to: Seq[Any]): Boolean = { + val rows = dataMap.remove(from).getOrElse(new BufferedRows(from)) + val newRows = new BufferedRows(to) + rows.rows.foreach { r => + val newRow = new GenericInternalRow(r.numFields) + for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType)) + for (i <- 0 until partitionSchema.length) { + val j = schema.fieldIndex(partitionSchema(i).name) + newRow.update(j, to(i)) + } + newRows.withRow(newRow) + } + dataMap.put(to, newRows).foreach { _ => + throw new IllegalStateException( + s"The ${to.mkString("[", ", ", "]")} partition exists already") + } + true + } + + protected def removePartitionKey(key: Seq[Any]): Unit = dataMap.synchronized { + dataMap.remove(key) + } + + protected def createPartitionKey(key: Seq[Any]): Unit = dataMap.synchronized { + if (!dataMap.contains(key)) { + val emptyRows = new BufferedRows(key) + val rows = if (key.length == schema.length) { + emptyRows.withRow(InternalRow.fromSeq(key)) + } else emptyRows + dataMap.put(key, rows) + } + } + + protected def clearPartition(key: Seq[Any]): Unit = dataMap.synchronized { + assert(dataMap.contains(key)) + dataMap(key).clear() + } + + def withData(data: Array[BufferedRows]): InMemoryBaseTable = { + withData(data, schema) + } + + def withData( + data: Array[BufferedRows], + writeSchema: StructType): InMemoryBaseTable = dataMap.synchronized { + data.foreach(_.rows.foreach { row => + val key = getKey(row, writeSchema) + dataMap += dataMap.get(key) + .map(key -> _.withRow(row)) + .getOrElse(key -> new BufferedRows(key).withRow(row)) + addPartitionKey(key) + }) + this + 
} + + override def capabilities: util.Set[TableCapability] = util.EnumSet.of( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.STREAMING_WRITE, + TableCapability.OVERWRITE_BY_FILTER, + TableCapability.OVERWRITE_DYNAMIC, + TableCapability.TRUNCATE) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new InMemoryScanBuilder(schema) + } + + class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder + with SupportsPushDownRequiredColumns { + private var schema: StructType = tableSchema + + override def build: Scan = + new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema) + + override def pruneColumns(requiredSchema: StructType): Unit = { + val schemaNames = metadataColumnNames ++ tableSchema.map(_.name) + schema = StructType(requiredSchema.filter(f => schemaNames.contains(f.name))) + } + } + + case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics + + abstract class BatchScanBaseClass( + var data: Seq[InputPartition], + readSchema: StructType, + tableSchema: StructType) + extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning { + + override def toBatch: Batch = this + + override def estimateStatistics(): Statistics = { + if (data.isEmpty) { + return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L)) + } + + val inputPartitions = data.map(_.asInstanceOf[BufferedRows]) + val numRows = inputPartitions.map(_.rows.size).sum + // we assume an average object header is 12 bytes + val objectHeaderSizeInBytes = 12L + val rowSizeInBytes = objectHeaderSizeInBytes + schema.defaultSize + val sizeInBytes = numRows * rowSizeInBytes + InMemoryStats(OptionalLong.of(sizeInBytes), OptionalLong.of(numRows)) + } + + override def outputPartitioning(): Partitioning = { + if (InMemoryBaseTable.this.partitioning.nonEmpty) { + new KeyGroupedPartitioning( + InMemoryBaseTable.this.partitioning.map(_.asInstanceOf[Expression]), + data.size) + } else { + new UnknownPartitioning(data.size) + } + } + + override def planInputPartitions(): Array[InputPartition] = data.toArray + + override def createReaderFactory(): PartitionReaderFactory = { + val metadataColumns = readSchema.map(_.name).filter(metadataColumnNames.contains) + val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name)) + new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema) + } + } + + case class InMemoryBatchScan( + var _data: Seq[InputPartition], + readSchema: StructType, + tableSchema: StructType) + extends BatchScanBaseClass (_data, readSchema, tableSchema) with SupportsRuntimeFiltering { + + override def filterAttributes(): Array[NamedReference] = { + val scanFields = readSchema.fields.map(_.name).toSet + partitioning.flatMap(_.references) + .filter(ref => scanFields.contains(ref.fieldNames.mkString("."))) + } + + override def filter(filters: Array[Filter]): Unit = { + if (partitioning.length == 1 && partitioning.head.references().length == 1) { + val ref = partitioning.head.references().head + filters.foreach { + case In(attrName, values) if attrName == ref.toString => + val matchingKeys = values.map(_.toString).toSet + data = data.filter(partition => { + val key = partition.asInstanceOf[BufferedRows].keyString + matchingKeys.contains(key) + }) + + case _ => // skip + } + } + } + } + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + InMemoryBaseTable.maybeSimulateFailedTableWrite(new 
CaseInsensitiveStringMap(properties)) + InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options) + + new WriteBuilder with SupportsTruncate with SupportsOverwrite + with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend { + + private var writer: BatchWrite = Append + private var streamingWriter: StreamingWrite = StreamingAppend + + override def truncate(): WriteBuilder = { + assert(writer == Append) + writer = TruncateAndAppend + streamingWriter = StreamingTruncateAndAppend + this + } + + override def overwrite(filters: Array[Filter]): WriteBuilder = { + assert(writer == Append) + writer = new Overwrite(filters) + streamingWriter = new StreamingNotSupportedOperation( + s"overwrite (${filters.mkString("filters(", ", ", ")")})") + this + } + + override def overwriteDynamicPartitions(): WriteBuilder = { + assert(writer == Append) + writer = DynamicOverwrite + streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions") + this + } + + override def build(): Write = new Write with RequiresDistributionAndOrdering { + override def requiredDistribution: Distribution = distribution + + override def distributionStrictlyRequired = isDistributionStrictlyRequired + + override def requiredOrdering: Array[SortOrder] = ordering + + override def requiredNumPartitions(): Int = { + numPartitions.getOrElse(0) + } + + override def toBatch: BatchWrite = writer + + override def toStreaming: StreamingWrite = streamingWriter match { + case exc: StreamingNotSupportedOperation => exc.throwsException() + case s => s + } + + override def supportedCustomMetrics(): Array[CustomMetric] = { + Array(new InMemorySimpleCustomMetric) + } + } + } + } + + protected abstract class TestBatchWrite extends BatchWrite { + override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { + BufferedRowsWriterFactory + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + } + + private object Append extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + + private object DynamicOverwrite extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val newData = messages.map(_.asInstanceOf[BufferedRows]) + dataMap --= newData.flatMap(_.rows.map(getKey)) + withData(newData) + } + } + + private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val deleteKeys = InMemoryBaseTable.filtersToKeys( + dataMap.keys, partCols.map(_.toSeq.quoted), filters) + dataMap --= deleteKeys + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + + private object TruncateAndAppend extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + dataMap.clear + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + + private abstract class TestStreamingWrite extends StreamingWrite { + def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = { + BufferedRowsWriterFactory + } + + def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + } + + private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite { + override def createStreamingWriterFactory(info: 
PhysicalWriteInfo): StreamingDataWriterFactory = + throwsException() + + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = + throwsException() + + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = + throwsException() + + def throwsException[T](): T = throw new IllegalStateException("The operation " + + s"${operation} isn't supported for streaming query.") + } + + private object StreamingAppend extends TestStreamingWrite { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + dataMap.synchronized { + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + } + + private object StreamingTruncateAndAppend extends TestStreamingWrite { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + dataMap.synchronized { + dataMap.clear + withData(messages.map(_.asInstanceOf[BufferedRows])) + } + } + } + +// override def canDeleteWhere(filters: Array[Filter]): Boolean = { +// InMemoryTable.supportsFilters(filters) +// } +// +// override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized { +// import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper +// dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters) +// } +} + +object InMemoryBaseTable { + val SIMULATE_FAILED_WRITE_OPTION = "spark.sql.test.simulateFailedWrite" + + def filtersToKeys( + keys: Iterable[Seq[Any]], + partitionNames: Seq[String], + filters: Array[Filter]): Iterable[Seq[Any]] = { + keys.filter { partValues => + filters.flatMap(splitAnd).forall { + case EqualTo(attr, value) => + value == extractValue(attr, partitionNames, partValues) + case EqualNullSafe(attr, value) => + val attrVal = extractValue(attr, partitionNames, partValues) + if (attrVal == null && value === null) { + true + } else if (attrVal == null || value === null) { + false + } else { + value == attrVal + } + case IsNull(attr) => + null == extractValue(attr, partitionNames, partValues) + case IsNotNull(attr) => + null != extractValue(attr, partitionNames, partValues) + case AlwaysTrue() => true + case f => + throw new IllegalArgumentException(s"Unsupported filter type: $f") + } + } + } + + def supportsFilters(filters: Array[Filter]): Boolean = { + filters.flatMap(splitAnd).forall { + case _: EqualTo => true + case _: EqualNullSafe => true + case _: IsNull => true + case _: IsNotNull => true + case _: AlwaysTrue => true + case _ => false + } + } + + private def extractValue( + attr: String, + partFieldNames: Seq[String], + partValues: Seq[Any]): Any = { + partFieldNames.zipWithIndex.find(_._1 == attr) match { + case Some((_, partIndex)) => + partValues(partIndex) + case _ => + throw new IllegalArgumentException(s"Unknown filter attribute: $attr") + } + } + + private def splitAnd(filter: Filter): Seq[Filter] = { + filter match { + case And(left, right) => splitAnd(left) ++ splitAnd(right) + case _ => filter :: Nil + } + } + + def maybeSimulateFailedTableWrite(tableOptions: CaseInsensitiveStringMap): Unit = { + if (tableOptions.getBoolean(SIMULATE_FAILED_WRITE_OPTION, false)) { + throw new IllegalStateException("Manual write to table failure.") + } + } +} + +class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage + with InputPartition with HasPartitionKey with Serializable { + val rows = new mutable.ArrayBuffer[InternalRow]() + + def withRow(row: InternalRow): BufferedRows = { + rows.append(row) + this + } + + def keyString(): String = 
key.toArray.mkString("/") + + override def partitionKey(): InternalRow = { + InternalRow.fromSeq(key) + } + + def clear(): Unit = rows.clear() +} + +private class BufferedRowsReaderFactory( + metadataColumnNames: Seq[String], + nonMetaDataColumns: Seq[StructField], + tableSchema: StructType) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + new BufferedRowsReader(partition.asInstanceOf[BufferedRows], metadataColumnNames, + nonMetaDataColumns, tableSchema) + } +} + +private class BufferedRowsReader( + partition: BufferedRows, + metadataColumnNames: Seq[String], + nonMetadataColumns: Seq[StructField], + tableSchema: StructType) extends PartitionReader[InternalRow] { + private def addMetadata(row: InternalRow): InternalRow = { + val metadataRow = new GenericInternalRow(metadataColumnNames.map { + case "index" => index + case "_partition" => UTF8String.fromString(partition.keyString) + }.toArray) + new JoinedRow(row, metadataRow) + } + + private var index: Int = -1 + + override def next(): Boolean = { + index += 1 + index < partition.rows.length + } + + override def get(): InternalRow = { + val originalRow = partition.rows(index) + val values = new Array[Any](nonMetadataColumns.length) + nonMetadataColumns.zipWithIndex.foreach { case (col, idx) => + values(idx) = extractFieldValue(col, tableSchema, originalRow) + } + addMetadata(new GenericInternalRow(values)) + } + + override def close(): Unit = {} + + private def extractFieldValue( + field: StructField, + schema: StructType, + row: InternalRow): Any = { + val index = schema.fieldIndex(field.name) + field.dataType match { + case StructType(fields) => + if (row.isNullAt(index)) { + return null + } + val childRow = row.toSeq(schema)(index).asInstanceOf[InternalRow] + val childSchema = schema(index).dataType.asInstanceOf[StructType] + val resultValue = new Array[Any](fields.length) + fields.zipWithIndex.foreach { case (childField, idx) => + val childValue = extractFieldValue(childField, childSchema, childRow) + resultValue(idx) = childValue + } + new GenericInternalRow(resultValue) + case dt => + row.get(index, dt) + } + } +} + +private object BufferedRowsWriterFactory extends DataWriterFactory with StreamingDataWriterFactory { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + new BufferWriter + } + + override def createWriter( + partitionId: Int, + taskId: Long, + epochId: Long): DataWriter[InternalRow] = { + new BufferWriter + } +} + +private class BufferWriter extends DataWriter[InternalRow] { + private val buffer = new BufferedRows + + override def write(row: InternalRow): Unit = buffer.rows.append(row.copy()) + + override def commit(): WriterCommitMessage = buffer + + override def abort(): Unit = {} + + override def close(): Unit = {} + + override def currentMetricsValues(): Array[CustomTaskMetric] = { + val metric = new CustomTaskMetric { + override def name(): String = "in_memory_buffer_rows" + + override def value(): Long = buffer.rows.size + } + Array(metric) + } +} + +class InMemorySimpleCustomMetric extends CustomMetric { + override def name(): String = "in_memory_buffer_rows" + override def description(): String = "number of rows in buffer" + override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { + s"in-memory rows: ${taskMetrics.sum}" + } +} From 469c3b5be88710a85dc366a1d470cbddef05262b Mon Sep 17 00:00:00 2001 From: huaxingao Date: Fri, 29 Jul 2022 10:01:19 -0700 Subject: [PATCH 4/8] Revert "extract 
common code to a method" This reverts commit a167e611f418b1a6956c7b88a1a26677e6237251. --- .../connector/catalog/SupportsDeleteV2.java | 82 --- .../internal/connector/PredicateUtils.scala | 21 +- .../connector/catalog/InMemoryBaseTable.scala | 677 ------------------ 3 files changed, 6 insertions(+), 774 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDeleteV2.java delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDeleteV2.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDeleteV2.java deleted file mode 100644 index 64961045c8894..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsDeleteV2.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.catalog; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.filter.AlwaysTrue; -import org.apache.spark.sql.connector.expressions.filter.Predicate; - -/** - * A mix-in interface for {@link Table} delete support. Data sources can implement this - * interface to provide the ability to delete data from tables that matches filter expressions. - * - * @since 3.4.0 - */ -@Evolving -public interface SupportsDeleteV2 extends TruncatableTable { - - /** - * Checks whether it is possible to delete data from a data source table that matches filter - * expressions. - *
<p>
- * Rows should be deleted from the data source iff all of the filter expressions match. - * That is, the expressions must be interpreted as a set of filters that are ANDed together. - *
<p>
- * Spark will call this method at planning time to check whether {@link #deleteWhere(Predicate[])} - * would reject the delete operation because it requires significant effort. If this method - * returns false, Spark will not call {@link #deleteWhere(Predicate[])} and will try to rewrite - * the delete operation and produce row-level changes if the data source table supports deleting - * individual records. - * - * @param predicates V2 filter expressions, used to select rows to delete when all expressions match - * @return true if the delete operation can be performed - * - * @since 3.4.0 - */ - default boolean canDeleteWhere(Predicate[] predicates) { - return true; - } - - /** - * Delete data from a data source table that matches filter expressions. Note that this method - * will be invoked only if {@link #canDeleteWhere(Predicate[])} returns true. - *
<p>
- * Rows are deleted from the data source iff all of the filter expressions match. That is, the - * expressions must be interpreted as a set of filters that are ANDed together. - *
<p>
- * Implementations may reject a delete operation if the delete isn't possible without significant - * effort. For example, partitioned data sources may reject deletes that do not filter by - * partition columns because the filter may require rewriting files without deleted records. - * To reject a delete implementations should throw {@link IllegalArgumentException} with a clear - * error message that identifies which expression was rejected. - * - * @param predicates predicate expressions, used to select rows to delete when all expressions match - * @throws IllegalArgumentException If the delete is rejected due to required effort - */ - void deleteWhere(Predicate[] predicates); - - @Override - default boolean truncateTable() { - Predicate[] filters = new Predicate[] { new AlwaysTrue() }; - boolean canDelete = canDeleteWhere(filters); - if (canDelete) { - deleteWhere(filters); - } - return canDelete; - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index 309d938b1b96c..24ea73ec9b65c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -26,17 +26,6 @@ import org.apache.spark.sql.types.StringType private[sql] object PredicateUtils { def toV1(predicate: Predicate): Option[Filter] = { - - def isValidBinaryPredicate(): Boolean = { - if (predicate.children().length == 2 && - predicate.children()(0).isInstanceOf[NamedReference] && - predicate.children()(1).isInstanceOf[LiteralValue[_]]) { - true - } else { - false - } - } - predicate.name() match { case "IN" if predicate.children()(0).isInstanceOf[NamedReference] => val attribute = predicate.children()(0).toString @@ -54,7 +43,9 @@ private[sql] object PredicateUtils { Some(In(attribute, Array.empty[Any])) } - case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate => + case "=" | "<=>" | ">" | "<" | ">=" | "<=" if predicate.children().length == 2 && + predicate.children()(0).isInstanceOf[NamedReference] && + predicate.children()(1).isInstanceOf[LiteralValue[_]] => val attribute = predicate.children()(0).toString val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] predicate.name() match { @@ -86,7 +77,9 @@ private[sql] object PredicateUtils { case "IS_NOT_NULL" => Some(IsNotNull(attribute)) } - case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate => + case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if predicate.children().length == 2 && + predicate.children()(0).isInstanceOf[NamedReference] && + predicate.children()(1).isInstanceOf[LiteralValue[_]] => val attribute = predicate.children()(0).toString val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] if (!value.dataType.sameType(StringType)) return None @@ -120,6 +113,4 @@ private[sql] object PredicateUtils { case _ => None } } - - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala deleted file mode 100644 index c76a9969077cb..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ /dev/null @@ -1,677 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.catalog - -import java.time.{Instant, ZoneId} -import java.time.temporal.ChronoUnit -import java.util -import java.util.OptionalLong - -import scala.collection.mutable - -import org.scalatest.Assertions._ - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow} -import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils} -import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} -import org.apache.spark.sql.connector.expressions._ -import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} -import org.apache.spark.sql.connector.read._ -import org.apache.spark.sql.connector.read.partitioning.{KeyGroupedPartitioning, Partitioning, UnknownPartitioning} -import org.apache.spark.sql.connector.write._ -import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} -import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.unsafe.types.UTF8String - -/** - * A simple in-memory table. Rows are stored as a buffered group produced by each output task. 
- */ -class InMemoryBaseTable( - val name: String, - val schema: StructType, - override val partitioning: Array[Transform], - override val properties: util.Map[String, String], - val distribution: Distribution = Distributions.unspecified(), - val ordering: Array[SortOrder] = Array.empty, - val numPartitions: Option[Int] = None, - val isDistributionStrictlyRequired: Boolean = true) - extends Table with SupportsRead with SupportsWrite - with SupportsMetadataColumns { - - protected object PartitionKeyColumn extends MetadataColumn { - override def name: String = "_partition" - override def dataType: DataType = StringType - override def comment: String = "Partition key used to store the row" - } - - private object IndexColumn extends MetadataColumn { - override def name: String = "index" - override def dataType: DataType = IntegerType - override def comment: String = "Metadata column used to conflict with a data column" - } - - // purposely exposes a metadata column that conflicts with a data column in some tests - override val metadataColumns: Array[MetadataColumn] = Array(IndexColumn, PartitionKeyColumn) - private val metadataColumnNames = metadataColumns.map(_.name).toSet -- schema.map(_.name) - - private val allowUnsupportedTransforms = - properties.getOrDefault("allow-unsupported-transforms", "false").toBoolean - - partitioning.foreach { - case _: IdentityTransform => - case _: YearsTransform => - case _: MonthsTransform => - case _: DaysTransform => - case _: HoursTransform => - case _: BucketTransform => - case _: SortedBucketTransform => - case t if !allowUnsupportedTransforms => - throw new IllegalArgumentException(s"Transform $t is not a supported transform") - } - - // The key `Seq[Any]` is the partition values. - val dataMap: mutable.Map[Seq[Any], BufferedRows] = mutable.Map.empty - - def data: Array[BufferedRows] = dataMap.values.toArray - - def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq - - private val partCols: Array[Array[String]] = partitioning.flatMap(_.references).map { ref => - schema.findNestedField(ref.fieldNames(), includeCollections = false) match { - case Some(_) => ref.fieldNames() - case None => throw new IllegalArgumentException(s"${ref.describe()} does not exist.") - } - } - - private val UTC = ZoneId.of("UTC") - private val EPOCH_LOCAL_DATE = Instant.EPOCH.atZone(UTC).toLocalDate - - protected def getKey(row: InternalRow): Seq[Any] = { - getKey(row, schema) - } - - protected def getKey(row: InternalRow, rowSchema: StructType): Seq[Any] = { - @scala.annotation.tailrec - def extractor( - fieldNames: Array[String], - schema: StructType, - row: InternalRow): (Any, DataType) = { - val index = schema.fieldIndex(fieldNames(0)) - val value = row.toSeq(schema).apply(index) - if (fieldNames.length > 1) { - (value, schema(index).dataType) match { - case (row: InternalRow, nestedSchema: StructType) => - extractor(fieldNames.drop(1), nestedSchema, row) - case (_, dataType) => - throw new IllegalArgumentException(s"Unsupported type, ${dataType.simpleString}") - } - } else { - (value, schema(index).dataType) - } - } - - val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(rowSchema) - partitioning.map { - case IdentityTransform(ref) => - extractor(ref.fieldNames, cleanedSchema, row)._1 - case YearsTransform(ref) => - extractor(ref.fieldNames, cleanedSchema, row) match { - case (days: Int, DateType) => - ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) - case (micros: Long, TimestampType) => - val localDate = 
DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate - ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) - case (v, t) => - throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") - } - case MonthsTransform(ref) => - extractor(ref.fieldNames, cleanedSchema, row) match { - case (days: Int, DateType) => - ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) - case (micros: Long, TimestampType) => - val localDate = DateTimeUtils.microsToInstant(micros).atZone(UTC).toLocalDate - ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, localDate) - case (v, t) => - throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") - } - case DaysTransform(ref) => - extractor(ref.fieldNames, cleanedSchema, row) match { - case (days, DateType) => - days - case (micros: Long, TimestampType) => - ChronoUnit.DAYS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) - case (v, t) => - throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") - } - case HoursTransform(ref) => - extractor(ref.fieldNames, cleanedSchema, row) match { - case (micros: Long, TimestampType) => - ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) - case (v, t) => - throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") - } - case BucketTransform(numBuckets, cols, _) => - val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row)) - var valueHashCode = 0 - valueTypePairs.foreach( pair => - if ( pair._1 != null) valueHashCode += pair._1.hashCode() - ) - var dataTypeHashCode = 0 - valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode()) - ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets - } - } - - protected def addPartitionKey(key: Seq[Any]): Unit = {} - - protected def renamePartitionKey( - partitionSchema: StructType, - from: Seq[Any], - to: Seq[Any]): Boolean = { - val rows = dataMap.remove(from).getOrElse(new BufferedRows(from)) - val newRows = new BufferedRows(to) - rows.rows.foreach { r => - val newRow = new GenericInternalRow(r.numFields) - for (i <- 0 until r.numFields) newRow.update(i, r.get(i, schema(i).dataType)) - for (i <- 0 until partitionSchema.length) { - val j = schema.fieldIndex(partitionSchema(i).name) - newRow.update(j, to(i)) - } - newRows.withRow(newRow) - } - dataMap.put(to, newRows).foreach { _ => - throw new IllegalStateException( - s"The ${to.mkString("[", ", ", "]")} partition exists already") - } - true - } - - protected def removePartitionKey(key: Seq[Any]): Unit = dataMap.synchronized { - dataMap.remove(key) - } - - protected def createPartitionKey(key: Seq[Any]): Unit = dataMap.synchronized { - if (!dataMap.contains(key)) { - val emptyRows = new BufferedRows(key) - val rows = if (key.length == schema.length) { - emptyRows.withRow(InternalRow.fromSeq(key)) - } else emptyRows - dataMap.put(key, rows) - } - } - - protected def clearPartition(key: Seq[Any]): Unit = dataMap.synchronized { - assert(dataMap.contains(key)) - dataMap(key).clear() - } - - def withData(data: Array[BufferedRows]): InMemoryBaseTable = { - withData(data, schema) - } - - def withData( - data: Array[BufferedRows], - writeSchema: StructType): InMemoryBaseTable = dataMap.synchronized { - data.foreach(_.rows.foreach { row => - val key = getKey(row, writeSchema) - dataMap += dataMap.get(key) - .map(key -> _.withRow(row)) - .getOrElse(key -> new BufferedRows(key).withRow(row)) - addPartitionKey(key) - }) - this - 
} - - override def capabilities: util.Set[TableCapability] = util.EnumSet.of( - TableCapability.BATCH_READ, - TableCapability.BATCH_WRITE, - TableCapability.STREAMING_WRITE, - TableCapability.OVERWRITE_BY_FILTER, - TableCapability.OVERWRITE_DYNAMIC, - TableCapability.TRUNCATE) - - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - new InMemoryScanBuilder(schema) - } - - class InMemoryScanBuilder(tableSchema: StructType) extends ScanBuilder - with SupportsPushDownRequiredColumns { - private var schema: StructType = tableSchema - - override def build: Scan = - new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition]), schema, tableSchema) - - override def pruneColumns(requiredSchema: StructType): Unit = { - val schemaNames = metadataColumnNames ++ tableSchema.map(_.name) - schema = StructType(requiredSchema.filter(f => schemaNames.contains(f.name))) - } - } - - case class InMemoryStats(sizeInBytes: OptionalLong, numRows: OptionalLong) extends Statistics - - abstract class BatchScanBaseClass( - var data: Seq[InputPartition], - readSchema: StructType, - tableSchema: StructType) - extends Scan with Batch with SupportsReportStatistics with SupportsReportPartitioning { - - override def toBatch: Batch = this - - override def estimateStatistics(): Statistics = { - if (data.isEmpty) { - return InMemoryStats(OptionalLong.of(0L), OptionalLong.of(0L)) - } - - val inputPartitions = data.map(_.asInstanceOf[BufferedRows]) - val numRows = inputPartitions.map(_.rows.size).sum - // we assume an average object header is 12 bytes - val objectHeaderSizeInBytes = 12L - val rowSizeInBytes = objectHeaderSizeInBytes + schema.defaultSize - val sizeInBytes = numRows * rowSizeInBytes - InMemoryStats(OptionalLong.of(sizeInBytes), OptionalLong.of(numRows)) - } - - override def outputPartitioning(): Partitioning = { - if (InMemoryBaseTable.this.partitioning.nonEmpty) { - new KeyGroupedPartitioning( - InMemoryBaseTable.this.partitioning.map(_.asInstanceOf[Expression]), - data.size) - } else { - new UnknownPartitioning(data.size) - } - } - - override def planInputPartitions(): Array[InputPartition] = data.toArray - - override def createReaderFactory(): PartitionReaderFactory = { - val metadataColumns = readSchema.map(_.name).filter(metadataColumnNames.contains) - val nonMetadataColumns = readSchema.filterNot(f => metadataColumns.contains(f.name)) - new BufferedRowsReaderFactory(metadataColumns, nonMetadataColumns, tableSchema) - } - } - - case class InMemoryBatchScan( - var _data: Seq[InputPartition], - readSchema: StructType, - tableSchema: StructType) - extends BatchScanBaseClass (_data, readSchema, tableSchema) with SupportsRuntimeFiltering { - - override def filterAttributes(): Array[NamedReference] = { - val scanFields = readSchema.fields.map(_.name).toSet - partitioning.flatMap(_.references) - .filter(ref => scanFields.contains(ref.fieldNames.mkString("."))) - } - - override def filter(filters: Array[Filter]): Unit = { - if (partitioning.length == 1 && partitioning.head.references().length == 1) { - val ref = partitioning.head.references().head - filters.foreach { - case In(attrName, values) if attrName == ref.toString => - val matchingKeys = values.map(_.toString).toSet - data = data.filter(partition => { - val key = partition.asInstanceOf[BufferedRows].keyString - matchingKeys.contains(key) - }) - - case _ => // skip - } - } - } - } - - override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - InMemoryBaseTable.maybeSimulateFailedTableWrite(new 
CaseInsensitiveStringMap(properties)) - InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options) - - new WriteBuilder with SupportsTruncate with SupportsOverwrite - with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend { - - private var writer: BatchWrite = Append - private var streamingWriter: StreamingWrite = StreamingAppend - - override def truncate(): WriteBuilder = { - assert(writer == Append) - writer = TruncateAndAppend - streamingWriter = StreamingTruncateAndAppend - this - } - - override def overwrite(filters: Array[Filter]): WriteBuilder = { - assert(writer == Append) - writer = new Overwrite(filters) - streamingWriter = new StreamingNotSupportedOperation( - s"overwrite (${filters.mkString("filters(", ", ", ")")})") - this - } - - override def overwriteDynamicPartitions(): WriteBuilder = { - assert(writer == Append) - writer = DynamicOverwrite - streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions") - this - } - - override def build(): Write = new Write with RequiresDistributionAndOrdering { - override def requiredDistribution: Distribution = distribution - - override def distributionStrictlyRequired = isDistributionStrictlyRequired - - override def requiredOrdering: Array[SortOrder] = ordering - - override def requiredNumPartitions(): Int = { - numPartitions.getOrElse(0) - } - - override def toBatch: BatchWrite = writer - - override def toStreaming: StreamingWrite = streamingWriter match { - case exc: StreamingNotSupportedOperation => exc.throwsException() - case s => s - } - - override def supportedCustomMetrics(): Array[CustomMetric] = { - Array(new InMemorySimpleCustomMetric) - } - } - } - } - - protected abstract class TestBatchWrite extends BatchWrite { - override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { - BufferedRowsWriterFactory - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = {} - } - - private object Append extends TestBatchWrite { - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { - withData(messages.map(_.asInstanceOf[BufferedRows])) - } - } - - private object DynamicOverwrite extends TestBatchWrite { - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { - val newData = messages.map(_.asInstanceOf[BufferedRows]) - dataMap --= newData.flatMap(_.rows.map(getKey)) - withData(newData) - } - } - - private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { - import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { - val deleteKeys = InMemoryBaseTable.filtersToKeys( - dataMap.keys, partCols.map(_.toSeq.quoted), filters) - dataMap --= deleteKeys - withData(messages.map(_.asInstanceOf[BufferedRows])) - } - } - - private object TruncateAndAppend extends TestBatchWrite { - override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { - dataMap.clear - withData(messages.map(_.asInstanceOf[BufferedRows])) - } - } - - private abstract class TestStreamingWrite extends StreamingWrite { - def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = { - BufferedRowsWriterFactory - } - - def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - } - - private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite { - override def createStreamingWriterFactory(info: 
PhysicalWriteInfo): StreamingDataWriterFactory = - throwsException() - - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = - throwsException() - - override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = - throwsException() - - def throwsException[T](): T = throw new IllegalStateException("The operation " + - s"${operation} isn't supported for streaming query.") - } - - private object StreamingAppend extends TestStreamingWrite { - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { - dataMap.synchronized { - withData(messages.map(_.asInstanceOf[BufferedRows])) - } - } - } - - private object StreamingTruncateAndAppend extends TestStreamingWrite { - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { - dataMap.synchronized { - dataMap.clear - withData(messages.map(_.asInstanceOf[BufferedRows])) - } - } - } - -// override def canDeleteWhere(filters: Array[Filter]): Boolean = { -// InMemoryTable.supportsFilters(filters) -// } -// -// override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized { -// import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper -// dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters) -// } -} - -object InMemoryBaseTable { - val SIMULATE_FAILED_WRITE_OPTION = "spark.sql.test.simulateFailedWrite" - - def filtersToKeys( - keys: Iterable[Seq[Any]], - partitionNames: Seq[String], - filters: Array[Filter]): Iterable[Seq[Any]] = { - keys.filter { partValues => - filters.flatMap(splitAnd).forall { - case EqualTo(attr, value) => - value == extractValue(attr, partitionNames, partValues) - case EqualNullSafe(attr, value) => - val attrVal = extractValue(attr, partitionNames, partValues) - if (attrVal == null && value === null) { - true - } else if (attrVal == null || value === null) { - false - } else { - value == attrVal - } - case IsNull(attr) => - null == extractValue(attr, partitionNames, partValues) - case IsNotNull(attr) => - null != extractValue(attr, partitionNames, partValues) - case AlwaysTrue() => true - case f => - throw new IllegalArgumentException(s"Unsupported filter type: $f") - } - } - } - - def supportsFilters(filters: Array[Filter]): Boolean = { - filters.flatMap(splitAnd).forall { - case _: EqualTo => true - case _: EqualNullSafe => true - case _: IsNull => true - case _: IsNotNull => true - case _: AlwaysTrue => true - case _ => false - } - } - - private def extractValue( - attr: String, - partFieldNames: Seq[String], - partValues: Seq[Any]): Any = { - partFieldNames.zipWithIndex.find(_._1 == attr) match { - case Some((_, partIndex)) => - partValues(partIndex) - case _ => - throw new IllegalArgumentException(s"Unknown filter attribute: $attr") - } - } - - private def splitAnd(filter: Filter): Seq[Filter] = { - filter match { - case And(left, right) => splitAnd(left) ++ splitAnd(right) - case _ => filter :: Nil - } - } - - def maybeSimulateFailedTableWrite(tableOptions: CaseInsensitiveStringMap): Unit = { - if (tableOptions.getBoolean(SIMULATE_FAILED_WRITE_OPTION, false)) { - throw new IllegalStateException("Manual write to table failure.") - } - } -} - -class BufferedRows(val key: Seq[Any] = Seq.empty) extends WriterCommitMessage - with InputPartition with HasPartitionKey with Serializable { - val rows = new mutable.ArrayBuffer[InternalRow]() - - def withRow(row: InternalRow): BufferedRows = { - rows.append(row) - this - } - - def keyString(): String = 
key.toArray.mkString("/") - - override def partitionKey(): InternalRow = { - InternalRow.fromSeq(key) - } - - def clear(): Unit = rows.clear() -} - -private class BufferedRowsReaderFactory( - metadataColumnNames: Seq[String], - nonMetaDataColumns: Seq[StructField], - tableSchema: StructType) extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - new BufferedRowsReader(partition.asInstanceOf[BufferedRows], metadataColumnNames, - nonMetaDataColumns, tableSchema) - } -} - -private class BufferedRowsReader( - partition: BufferedRows, - metadataColumnNames: Seq[String], - nonMetadataColumns: Seq[StructField], - tableSchema: StructType) extends PartitionReader[InternalRow] { - private def addMetadata(row: InternalRow): InternalRow = { - val metadataRow = new GenericInternalRow(metadataColumnNames.map { - case "index" => index - case "_partition" => UTF8String.fromString(partition.keyString) - }.toArray) - new JoinedRow(row, metadataRow) - } - - private var index: Int = -1 - - override def next(): Boolean = { - index += 1 - index < partition.rows.length - } - - override def get(): InternalRow = { - val originalRow = partition.rows(index) - val values = new Array[Any](nonMetadataColumns.length) - nonMetadataColumns.zipWithIndex.foreach { case (col, idx) => - values(idx) = extractFieldValue(col, tableSchema, originalRow) - } - addMetadata(new GenericInternalRow(values)) - } - - override def close(): Unit = {} - - private def extractFieldValue( - field: StructField, - schema: StructType, - row: InternalRow): Any = { - val index = schema.fieldIndex(field.name) - field.dataType match { - case StructType(fields) => - if (row.isNullAt(index)) { - return null - } - val childRow = row.toSeq(schema)(index).asInstanceOf[InternalRow] - val childSchema = schema(index).dataType.asInstanceOf[StructType] - val resultValue = new Array[Any](fields.length) - fields.zipWithIndex.foreach { case (childField, idx) => - val childValue = extractFieldValue(childField, childSchema, childRow) - resultValue(idx) = childValue - } - new GenericInternalRow(resultValue) - case dt => - row.get(index, dt) - } - } -} - -private object BufferedRowsWriterFactory extends DataWriterFactory with StreamingDataWriterFactory { - override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { - new BufferWriter - } - - override def createWriter( - partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[InternalRow] = { - new BufferWriter - } -} - -private class BufferWriter extends DataWriter[InternalRow] { - private val buffer = new BufferedRows - - override def write(row: InternalRow): Unit = buffer.rows.append(row.copy()) - - override def commit(): WriterCommitMessage = buffer - - override def abort(): Unit = {} - - override def close(): Unit = {} - - override def currentMetricsValues(): Array[CustomTaskMetric] = { - val metric = new CustomTaskMetric { - override def name(): String = "in_memory_buffer_rows" - - override def value(): Long = buffer.rows.size - } - Array(metric) - } -} - -class InMemorySimpleCustomMetric extends CustomMetric { - override def name(): String = "in_memory_buffer_rows" - override def description(): String = "number of rows in buffer" - override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = { - s"in-memory rows: ${taskMetrics.sum}" - } -} From 6cc27eac9853b8b6f0241ac7f13cd36cbd406641 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Fri, 29 Jul 2022 10:05:20 -0700 Subject: [PATCH 5/8] extract common 
code to a method --- .../internal/connector/PredicateUtils.scala | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index 24ea73ec9b65c..5f38b7053ec61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -26,6 +26,17 @@ import org.apache.spark.sql.types.StringType private[sql] object PredicateUtils { def toV1(predicate: Predicate): Option[Filter] = { + + def isValidBinaryPredicate(): Boolean = { + if (predicate.children().length == 2 && + predicate.children()(0).isInstanceOf[NamedReference] && + predicate.children()(1).isInstanceOf[LiteralValue[_]]) { + true + } else { + false + } + } + predicate.name() match { case "IN" if predicate.children()(0).isInstanceOf[NamedReference] => val attribute = predicate.children()(0).toString @@ -43,9 +54,7 @@ private[sql] object PredicateUtils { Some(In(attribute, Array.empty[Any])) } - case "=" | "<=>" | ">" | "<" | ">=" | "<=" if predicate.children().length == 2 && - predicate.children()(0).isInstanceOf[NamedReference] && - predicate.children()(1).isInstanceOf[LiteralValue[_]] => + case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate => val attribute = predicate.children()(0).toString val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] predicate.name() match { @@ -77,9 +86,7 @@ private[sql] object PredicateUtils { case "IS_NOT_NULL" => Some(IsNotNull(attribute)) } - case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if predicate.children().length == 2 && - predicate.children()(0).isInstanceOf[NamedReference] && - predicate.children()(1).isInstanceOf[LiteralValue[_]] => + case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate => val attribute = predicate.children()(0).toString val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] if (!value.dataType.sameType(StringType)) return None From 30dedd7380ca9f26a4541dc319c0314843127470 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Fri, 29 Jul 2022 10:25:29 -0700 Subject: [PATCH 6/8] addresss comments --- .../sql/internal/connector/PredicateUtils.scala | 16 ++++++++++++++-- .../datasources/v2/V2PredicateSuite.scala | 12 ++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index 5f38b7053ec61..1aefafd22f07a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -107,11 +107,23 @@ private[sql] object PredicateUtils { case "AND" => val and = predicate.asInstanceOf[V2And] - Some(And(toV1(and.left()).get, toV1(and.right()).get)) + val left = toV1(and.left()) + val right = toV1(and.right()) + if (left.nonEmpty && right.nonEmpty) { + Some(And(left.get, right.get)) + } else { + None + } case "OR" => val or = predicate.asInstanceOf[V2Or] - Some(Or(toV1(or.left()).get, toV1(or.right()).get)) + val left = toV1(or.left()) + val right = toV1(or.right()) + if (left.nonEmpty && right.nonEmpty) { + Some(Or(left.get, right.get)) + } else { + None + } case "NOT" => val not = 
predicate.asInstanceOf[V2Not] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala index aeae33a30dc22..7e77f18a405be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -266,6 +266,12 @@ class V2PredicateSuite extends SparkFunSuite { assert(PredicateUtils.toV1(predicate1).get == v1Filter) assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) + + val predicate3 = new And( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType), + LiteralValue(1, IntegerType)))) + assert(PredicateUtils.toV1(predicate3) == None) } test("Or") { @@ -284,6 +290,12 @@ class V2PredicateSuite extends SparkFunSuite { assert(PredicateUtils.toV1(predicate1).get == v1Filter) assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) + + val predicate3 = new Or( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType), + LiteralValue(1, IntegerType)))) + assert(PredicateUtils.toV1(predicate3) == None) } test("StringStartsWith") { From 485fe808946196e4b33fd308f32679e998508c7c Mon Sep 17 00:00:00 2001 From: huaxingao Date: Fri, 29 Jul 2022 21:50:01 -0700 Subject: [PATCH 7/8] fix Or.toV1 --- .../sql/internal/connector/PredicateUtils.scala | 4 +++- .../datasources/v2/V2PredicateSuite.scala | 14 +++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index 1aefafd22f07a..a19b7ded2baae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -121,8 +121,10 @@ private[sql] object PredicateUtils { val right = toV1(or.right()) if (left.nonEmpty && right.nonEmpty) { Some(Or(left.get, right.get)) + } else if (left.nonEmpty) { + left } else { - None + right } case "NOT" => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala index 7e77f18a405be..0025313f9486d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -291,11 +291,15 @@ class V2PredicateSuite extends SparkFunSuite { assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) - val predicate3 = new Or( - new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), - new Predicate("=", Array[Expression](LiteralValue(1, IntegerType), - LiteralValue(1, IntegerType)))) - assert(PredicateUtils.toV1(predicate3) == None) + val left = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + val predicate3 = new Or(left, + new 
Predicate("=", Array[Expression](LiteralValue(1, IntegerType)))) + assert(PredicateUtils.toV1(predicate3) == PredicateUtils.toV1(left)) + + val predicate4 = new Or( + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType)))) + assert(PredicateUtils.toV1(predicate4) == None) } test("StringStartsWith") { From 45386770c6aaf74da4e5031d1d4e58cb5c803f85 Mon Sep 17 00:00:00 2001 From: huaxingao Date: Mon, 1 Aug 2022 00:05:55 -0700 Subject: [PATCH 8/8] address comments --- .../internal/connector/PredicateUtils.scala | 60 +++++++++---------- .../datasources/v2/V2PredicateSuite.scala | 5 ++ 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index a19b7ded2baae..263edd82197bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -57,53 +57,47 @@ private[sql] object PredicateUtils { case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate => val attribute = predicate.children()(0).toString val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] - predicate.name() match { - case "=" => - Some(EqualTo(attribute, - CatalystTypeConverters.convertToScala(value.value, value.dataType))) - case "<=>" => - Some(EqualNullSafe(attribute, - CatalystTypeConverters.convertToScala(value.value, value.dataType))) - case ">" => - Some(GreaterThan(attribute, - CatalystTypeConverters.convertToScala(value.value, value.dataType))) - case ">=" => - Some(GreaterThanOrEqual(attribute, - CatalystTypeConverters.convertToScala(value.value, value.dataType))) - case "<" => - Some(LessThan(attribute, - CatalystTypeConverters.convertToScala(value.value, value.dataType))) - case "<=" => - Some(LessThanOrEqual(attribute, - CatalystTypeConverters.convertToScala(value.value, value.dataType))) + val v1Value = CatalystTypeConverters.convertToScala(value.value, value.dataType) + val v1Filter = predicate.name() match { + case "=" => EqualTo(attribute, v1Value) + case "<=>" => EqualNullSafe(attribute, v1Value) + case ">" => GreaterThan(attribute, v1Value) + case ">=" => GreaterThanOrEqual(attribute, v1Value) + case "<" => LessThan(attribute, v1Value) + case "<=" => LessThanOrEqual(attribute, v1Value) } + Some(v1Filter) case "IS_NULL" | "IS_NOT_NULL" if predicate.children().length == 1 && predicate.children()(0).isInstanceOf[NamedReference] => val attribute = predicate.children()(0).toString - predicate.name() match { - case "IS_NULL" => Some(IsNull(attribute)) - case "IS_NOT_NULL" => Some(IsNotNull(attribute)) + val v1Filter = predicate.name() match { + case "IS_NULL" => IsNull(attribute) + case "IS_NOT_NULL" => IsNotNull(attribute) } + Some(v1Filter) case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate => val attribute = predicate.children()(0).toString val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] if (!value.dataType.sameType(StringType)) return None - predicate.name() match { + val v1Value = value.value.toString + val v1Filter = predicate.name() match { case "STARTS_WITH" => - Some(StringStartsWith(attribute, value.value.toString)) + StringStartsWith(attribute, v1Value) case "ENDS_WITH" => - Some(StringEndsWith(attribute, value.value.toString)) + 
StringEndsWith(attribute, v1Value) case "CONTAINS" => - Some(StringContains(attribute, value.value.toString)) + StringContains(attribute, v1Value) } + Some(v1Filter) case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty => - predicate.name() match { - case "ALWAYS_TRUE" => Some(AlwaysTrue()) - case "ALWAYS_FALSE" => Some(AlwaysFalse()) + val v1Filter = predicate.name() match { + case "ALWAYS_TRUE" => AlwaysTrue() + case "ALWAYS_FALSE" => AlwaysFalse() } + Some(v1Filter) case "AND" => val and = predicate.asInstanceOf[V2And] @@ -128,8 +122,12 @@ private[sql] object PredicateUtils { } case "NOT" => - val not = predicate.asInstanceOf[V2Not] - Some(Not(toV1(not.child()).get)) + val child = toV1(predicate.asInstanceOf[V2Not].child()) + if (child.nonEmpty) { + Some(Not(child.get)) + } else { + None + } case _ => None } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala index 0025313f9486d..de556c50f5d4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -248,6 +248,11 @@ class V2PredicateSuite extends SparkFunSuite { assert(PredicateUtils.toV1(predicate1).get == v1Filter) assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) + + val predicate3 = new Not( + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType), + LiteralValue(1, IntegerType)))) + assert(PredicateUtils.toV1(predicate3) == None) } test("And") {
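
Note (illustration only, not part of the patches): as a quick sketch of the conversion rules these patches add, the snippet below builds a few V2 predicates the same way V2PredicateSuite does and records what the final PredicateUtils.toV1 is expected to return, following the behaviour asserted by the tests above. It is only a sketch: PredicateUtils is private[sql], so code like this compiles only inside Spark's own sql modules (for example inside a test suite), and the results in the comments are taken from the assertions in these patches rather than run independently.

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
    import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
    import org.apache.spark.sql.internal.connector.PredicateUtils
    import org.apache.spark.sql.types.{IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    // a > 1: column on the left, literal on the right, so it maps to a V1 GreaterThan
    val gt = new Predicate(">",
      Array[Expression](FieldReference("a"), LiteralValue(1, IntegerType)))

    // CONTAINS converts only when the literal is a string
    val contains = new Predicate("CONTAINS",
      Array[Expression](FieldReference("a"), LiteralValue(UTF8String.fromString("abc"), StringType)))
    val badContains = new Predicate("CONTAINS",
      Array[Expression](FieldReference("a"), LiteralValue(1, IntegerType)))

    // 1 = 1: no NamedReference child, so it has no V1 counterpart at all
    val unconvertible = new Predicate("=",
      Array[Expression](LiteralValue(1, IntegerType), LiteralValue(1, IntegerType)))

    PredicateUtils.toV1(gt)                            // Some(GreaterThan("a", 1))
    PredicateUtils.toV1(contains)                      // Some(StringContains("a", "abc"))
    PredicateUtils.toV1(badContains)                   // None: non-string literal
    PredicateUtils.toV1(new V2And(gt, unconvertible))  // None: AND needs both sides to convert
    PredicateUtils.toV1(new V2Or(gt, unconvertible))   // Some(GreaterThan("a", 1)):
                                                       // OR keeps the side that converts
    PredicateUtils.toV1(new V2Not(unconvertible))      // None: NOT needs its child to convert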
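
Note (illustration only): the InMemoryBaseTable test harness shown earlier in the series derives a partition key from each transform, and for the date transforms this is plain java.time arithmetic against the Unix epoch. The standalone sketch below uses java.time directly in place of Spark's DateTimeUtils helpers and a hypothetical date value, just to show what years(col), months(col), and days(col) produce for a DATE stored as days since 1970-01-01.

    import java.time.LocalDate
    import java.time.temporal.ChronoUnit

    // Spark stores a DATE as the number of days since 1970-01-01
    val epoch = LocalDate.ofEpochDay(0)
    val days = LocalDate.of(2022, 8, 1).toEpochDay                // hypothetical partition value

    ChronoUnit.YEARS.between(epoch, LocalDate.ofEpochDay(days))   // 52    -> years(col) key
    ChronoUnit.MONTHS.between(epoch, LocalDate.ofEpochDay(days))  // 631   -> months(col) key
    days                                                          // 19205 -> days(col) key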