diff --git a/pom.xml b/pom.xml
index 5e29ed5792518..acebec1e2ec3e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -128,7 +128,7 @@
     <hive.version.short>2.3</hive.version.short>
     <!-- note that this should be compatible with Kafka brokers version 0.10 and up -->
-    <kafka.version>3.2.0</kafka.version>
+    <kafka.version>3.2.1</kafka.version>
     <!-- After 10.15.1.3, the minimum required version is JDK9 -->
     <derby.version>10.14.2.0</derby.version>
     <parquet.version>1.12.3</parquet.version>
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 465541abdaa9e..dccf48b6991b0 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -9012,9 +9012,9 @@ def add_prefix(self, prefix: str) -> "DataFrame":
         2  3  5
         3  4  6
         """
-        assert isinstance(prefix, str)
+        f = partial("{prefix}{}".format, prefix=prefix)
         return self._apply_series_op(
-            lambda psser: psser.rename(tuple([prefix + i for i in psser._column_label]))
+            lambda psser: psser.rename(tuple([f(i) for i in psser._column_label]))
         )
 
     def add_suffix(self, suffix: str) -> "DataFrame":
@@ -9057,9 +9057,9 @@ def add_suffix(self, suffix: str) -> "DataFrame":
         2  3  5
         3  4  6
         """
-        assert isinstance(suffix, str)
+        f = partial("{}{suffix}".format, suffix=suffix)
         return self._apply_series_op(
-            lambda psser: psser.rename(tuple([i + suffix for i in psser._column_label]))
+            lambda psser: psser.rename(tuple([f(i) for i in psser._column_label]))
         )
 
     # TODO: include, and exclude should be implemented.
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index ff4c7fcc8f140..a5540d8acc5d4 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -3170,11 +3170,10 @@ def add_prefix(self, prefix: str) -> "Series":
         item_3    4
         dtype: int64
         """
-        assert isinstance(prefix, str)
         internal = self._internal.resolved_copy
         sdf = internal.spark_frame.select(
             [
-                F.concat(SF.lit(prefix), index_spark_column).alias(index_spark_column_name)
+                F.concat(SF.lit(str(prefix)), index_spark_column).alias(index_spark_column_name)
                 for index_spark_column, index_spark_column_name in zip(
                     internal.index_spark_columns, internal.index_spark_column_names
                 )
@@ -3225,11 +3224,10 @@ def add_suffix(self, suffix: str) -> "Series":
         3_item    4
         dtype: int64
         """
-        assert isinstance(suffix, str)
         internal = self._internal.resolved_copy
         sdf = internal.spark_frame.select(
             [
-                F.concat(index_spark_column, SF.lit(suffix)).alias(index_spark_column_name)
+                F.concat(index_spark_column, SF.lit(str(suffix))).alias(index_spark_column_name)
                 for index_spark_column, index_spark_column_name in zip(
                     internal.index_spark_columns, internal.index_spark_column_names
                 )
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 1361c44404a3d..a9b8ba9c96f65 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -2701,6 +2701,8 @@ def test_add_prefix(self):
         pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}, index=np.random.rand(4))
         psdf = ps.from_pandas(pdf)
         self.assert_eq(pdf.add_prefix("col_"), psdf.add_prefix("col_"))
+        self.assert_eq(pdf.add_prefix(1.1), psdf.add_prefix(1.1))
+        self.assert_eq(pdf.add_prefix(True), psdf.add_prefix(True))
 
         columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")])
         pdf.columns = columns
@@ -2711,6 +2713,8 @@ def test_add_suffix(self):
         pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}, index=np.random.rand(4))
         psdf = ps.from_pandas(pdf)
         self.assert_eq(pdf.add_suffix("first_series"), psdf.add_suffix("first_series"))
+        self.assert_eq(pdf.add_suffix(1.1), psdf.add_suffix(1.1))
+        self.assert_eq(pdf.add_suffix(True), psdf.add_suffix(True))
 
         columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")])
         pdf.columns = columns
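The format-based rename above is what makes non-string prefixes work: str.format stringifies its arguments the same way pandas does, which the new add_prefix(1.1)/add_prefix(True) test cases rely on. A minimal standalone sketch of the same trick (plain Python; the labels here are made up):

    from functools import partial

    # Mirror of the frame.py change: str.format stringifies non-string
    # prefixes, so add_prefix(1.1) renames "A" -> "1.1A" just like pandas.
    prefix = 1.1
    f = partial("{prefix}{}".format, prefix=prefix)

    labels = [("A",), ("B",)]  # hypothetical column labels
    print([tuple(f(i) for i in label) for label in labels])
    # [('1.1A',), ('1.1B',)]
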
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index 144df0f986a70..6ef7c805507b1 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -1293,6 +1293,8 @@ def test_add_prefix(self):
         pser = pd.Series([1, 2, 3, 4], name="0")
         psser = ps.from_pandas(pser)
         self.assert_eq(pser.add_prefix("item_"), psser.add_prefix("item_"))
+        self.assert_eq(pser.add_prefix(1.1), psser.add_prefix(1.1))
+        self.assert_eq(pser.add_prefix(False), psser.add_prefix(False))
 
         pser = pd.Series(
             [1, 2, 3],
@@ -1306,6 +1308,8 @@ def test_add_suffix(self):
         pser = pd.Series([1, 2, 3, 4], name="0")
         psser = ps.from_pandas(pser)
         self.assert_eq(pser.add_suffix("_item"), psser.add_suffix("_item"))
+        self.assert_eq(pser.add_suffix(1.1), psser.add_suffix(1.1))
+        self.assert_eq(pser.add_suffix(False), psser.add_suffix(False))
 
         pser = pd.Series(
             [1, 2, 3],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
index ace6b30d4ccec..263edd82197bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
@@ -19,14 +19,25 @@ package org.apache.spark.sql.internal.connector
 
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference}
-import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.sources.{Filter, In}
+import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
+import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
+import org.apache.spark.sql.types.StringType
 
 private[sql] object PredicateUtils {
 
   def toV1(predicate: Predicate): Option[Filter] = {
+
+    def isValidBinaryPredicate(): Boolean = {
+      if (predicate.children().length == 2 &&
+        predicate.children()(0).isInstanceOf[NamedReference] &&
+        predicate.children()(1).isInstanceOf[LiteralValue[_]]) {
+        true
+      } else {
+        false
+      }
+    }
+
     predicate.name() match {
-      // TODO: add conversion for other V2 Predicate
       case "IN" if predicate.children()(0).isInstanceOf[NamedReference] =>
         val attribute = predicate.children()(0).toString
         val values = predicate.children().drop(1)
@@ -43,6 +54,81 @@
           Some(In(attribute, Array.empty[Any]))
         }
 
+      case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate() =>
+        val attribute = predicate.children()(0).toString
+        val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+        val v1Value = CatalystTypeConverters.convertToScala(value.value, value.dataType)
+        val v1Filter = predicate.name() match {
+          case "=" => EqualTo(attribute, v1Value)
+          case "<=>" => EqualNullSafe(attribute, v1Value)
+          case ">" => GreaterThan(attribute, v1Value)
+          case ">=" => GreaterThanOrEqual(attribute, v1Value)
+          case "<" => LessThan(attribute, v1Value)
+          case "<=" => LessThanOrEqual(attribute, v1Value)
+        }
+        Some(v1Filter)
+
+      case "IS_NULL" | "IS_NOT_NULL" if predicate.children().length == 1 &&
+        predicate.children()(0).isInstanceOf[NamedReference] =>
+        val attribute = predicate.children()(0).toString
+        val v1Filter = predicate.name() match {
+          case "IS_NULL" => IsNull(attribute)
+          case "IS_NOT_NULL" => IsNotNull(attribute)
+        }
+        Some(v1Filter)
+
+      case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate() =>
+        val attribute = predicate.children()(0).toString
+        val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+        if (!value.dataType.sameType(StringType)) return None
+        val v1Value = value.value.toString
+        val v1Filter = predicate.name() match {
+          case "STARTS_WITH" =>
+            StringStartsWith(attribute, v1Value)
+          case "ENDS_WITH" =>
+            StringEndsWith(attribute, v1Value)
+          case "CONTAINS" =>
+            StringContains(attribute, v1Value)
+        }
+        Some(v1Filter)
+
+      case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty =>
+        val v1Filter = predicate.name() match {
+          case "ALWAYS_TRUE" => AlwaysTrue()
+          case "ALWAYS_FALSE" => AlwaysFalse()
+        }
+        Some(v1Filter)
+
+      case "AND" =>
+        val and = predicate.asInstanceOf[V2And]
+        val left = toV1(and.left())
+        val right = toV1(and.right())
+        if (left.nonEmpty && right.nonEmpty) {
+          Some(And(left.get, right.get))
+        } else {
+          None
+        }
+
+      case "OR" =>
+        val or = predicate.asInstanceOf[V2Or]
+        val left = toV1(or.left())
+        val right = toV1(or.right())
+        if (left.nonEmpty && right.nonEmpty) {
+          Some(Or(left.get, right.get))
+        } else if (left.nonEmpty) {
+          left
+        } else {
+          right
+        }
+
+      case "NOT" =>
+        val child = toV1(predicate.asInstanceOf[V2Not].child())
+        if (child.nonEmpty) {
+          Some(Not(child.get))
+        } else {
+          None
+        }
+
       case _ => None
     }
   }
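For orientation, the new conversion can be exercised end to end: build a V2 Predicate in the shape isValidBinaryPredicate() accepts (a NamedReference child followed by a LiteralValue child) and toV1 returns the matching V1 Filter. A rough sketch; these are Spark-internal classes, so it only compiles inside Spark's own source tree:

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
    import org.apache.spark.sql.connector.expressions.filter.Predicate
    import org.apache.spark.sql.internal.connector.PredicateUtils
    import org.apache.spark.sql.sources.EqualTo
    import org.apache.spark.sql.types.IntegerType

    // "a = 1" as a V2 predicate: NamedReference child plus LiteralValue child.
    val v2 = new Predicate("=",
      Array[Expression](FieldReference(Seq("a")), LiteralValue(1, IntegerType)))

    // The "=" arm converts it to the V1 EqualTo filter; shapes that match no
    // arm (or fail a guard) fall through to None.
    assert(PredicateUtils.toV1(v2) == Some(EqualTo("a", 1)))
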
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
index 2df8b8e56c44b..de556c50f5d4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter._
 import org.apache.spark.sql.execution.datasources.v2.V2PredicateSuite.ref
+import org.apache.spark.sql.internal.connector.PredicateUtils
 import org.apache.spark.sql.sources.{AlwaysFalse => V1AlwaysFalse, AlwaysTrue => V1AlwaysTrue, And => V1And, EqualNullSafe, EqualTo, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not => V1Not, Or => V1Or, StringContains, StringEndsWith, StringStartsWith}
 import org.apache.spark.sql.types.{IntegerType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
@@ -34,6 +35,9 @@ class V2PredicateSuite extends SparkFunSuite {
     assert(predicate1.describe.equals("a.B = 1"))
     val v1Filter1 = EqualTo(ref("a", "B").describe(), 1)
     assert(v1Filter1.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter1)
+    assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
 
     val predicate2 =
       new Predicate("=", Array[Expression](ref("a", "b.c"), LiteralValue(1, IntegerType)))
@@ -41,6 +45,9 @@
     assert(predicate2.describe.equals("a.`b.c` = 1"))
     val v1Filter2 = EqualTo(ref("a", "b.c").describe(), 1)
     assert(v1Filter2.toV2 == predicate2)
+    assert(PredicateUtils.toV1(predicate2).get == v1Filter2)
+    assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2)
+    assert(PredicateUtils.toV1(predicate2).get.toV2 == predicate2)
 
     val predicate3 =
       new Predicate("=", Array[Expression](ref("`a`.b", "c"), LiteralValue(1, IntegerType)))
@@ -48,6 +55,9 @@
     assert(predicate3.describe.equals("```a``.b`.c = 1"))
     val v1Filter3 = EqualTo(ref("`a`.b", "c").describe(), 1)
     assert(v1Filter3.toV2 == predicate3)
+    assert(PredicateUtils.toV1(predicate3).get == v1Filter3)
+    assert(PredicateUtils.toV1(v1Filter3.toV2).get == v1Filter3)
+    assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3)
   }
 
   test("AlwaysTrue") {
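The three assertions repeated throughout the suite below pin down a round-trip property between the two filter representations. A hypothetical helper (not part of the patch) makes the shape explicit:

    import org.apache.spark.sql.connector.expressions.filter.Predicate
    import org.apache.spark.sql.internal.connector.PredicateUtils
    import org.apache.spark.sql.sources.Filter

    // Hypothetical helper: toV1 and Filter.toV2 should invert each other for
    // any supported (predicate, filter) pair, as each test asserts inline.
    def checkRoundTrip(predicate: Predicate, filter: Filter): Unit = {
      assert(PredicateUtils.toV1(predicate).get == filter)          // V2 -> V1
      assert(PredicateUtils.toV1(filter.toV2).get == filter)        // V1 -> V2 -> V1
      assert(PredicateUtils.toV1(predicate).get.toV2 == predicate)  // V2 -> V1 -> V2
    }
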
@@ -59,6 +69,9 @@
 
     val v1Filter = V1AlwaysTrue
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("AlwaysFalse") {
@@ -70,6 +83,9 @@
 
     val v1Filter = V1AlwaysFalse
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("EqualTo") {
@@ -81,6 +97,9 @@
 
     val v1Filter = EqualTo("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("EqualNullSafe") {
@@ -92,6 +111,9 @@
 
     val v1Filter = EqualNullSafe("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("LessThan") {
@@ -103,6 +125,9 @@
 
     val v1Filter = LessThan("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("LessThanOrEqual") {
@@ -114,6 +139,9 @@
 
     val v1Filter = LessThanOrEqual("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("GreatThan") {
@@ -125,6 +153,9 @@
 
     val v1Filter = GreaterThan("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("GreatThanOrEqual") {
@@ -136,6 +167,9 @@
 
     val v1Filter = GreaterThanOrEqual("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("In") {
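The In test below exercises the one conversion that already existed before this change: child 0 is the attribute reference and the remaining children are the value list. A sketch of that path (same caveat about Spark-internal classes):

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
    import org.apache.spark.sql.connector.expressions.filter.Predicate
    import org.apache.spark.sql.internal.connector.PredicateUtils
    import org.apache.spark.sql.sources.In
    import org.apache.spark.sql.types.IntegerType

    // IN keeps the attribute as child 0 and the values as the remaining
    // children; toV1 collects the literals into a V1 In filter.
    val values = Array(1, 2, 3, 4).map(v => LiteralValue(v, IntegerType): Expression)
    val v2In = new Predicate("IN", FieldReference(Seq("a")) +: values)
    assert(PredicateUtils.toV1(v2In) == Some(In("a", Array(1, 2, 3, 4))))
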
@@ -161,9 +195,15 @@
 
     val v1Filter1 = In("a", Array(1, 2, 3, 4))
     assert(v1Filter1.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter1)
+    assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
 
     val v1Filter2 = In("a", values.map(_.value()))
     assert(v1Filter2.toV2 == predicate3)
+    assert(PredicateUtils.toV1(predicate3).get == v1Filter2)
+    assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2)
+    assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3)
   }
 
   test("IsNull") {
@@ -175,6 +215,9 @@
 
     val v1Filter = IsNull("a")
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("IsNotNull") {
@@ -186,6 +229,9 @@
 
     val v1Filter = IsNotNull("a")
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("Not") {
@@ -199,6 +245,14 @@
 
     val v1Filter = V1Not(LessThan("a", 1))
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+    val predicate3 = new Not(
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType),
+        LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate3) == None)
   }
 
   test("And") {
@@ -214,6 +268,15 @@
 
     val v1Filter = V1And(EqualTo("a", 1), EqualTo("b", 1))
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+    val predicate3 = new And(
+      new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType),
+        LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate3) == None)
   }
 
   test("Or") {
@@ -229,6 +292,19 @@
 
     val v1Filter = V1Or(EqualTo("a", 1), EqualTo("b", 1))
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+    val left = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))
+    val predicate3 = new Or(left,
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate3) == PredicateUtils.toV1(left))
+
+    val predicate4 = new Or(
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))),
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate4) == None)
   }
 
   test("StringStartsWith") {
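The extra cases in the Or test above pin down an asymmetry in PredicateUtils: AND gives up unless both children convert, while OR falls back to whichever child converted when only one side is expressible as a V1 filter. A sketch of that behavior, reusing the shapes from the tests:

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
    import org.apache.spark.sql.connector.expressions.filter.{And, Or, Predicate}
    import org.apache.spark.sql.internal.connector.PredicateUtils
    import org.apache.spark.sql.sources.EqualTo
    import org.apache.spark.sql.types.IntegerType

    val good = new Predicate("=",
      Array[Expression](FieldReference(Seq("a")), LiteralValue(1, IntegerType)))
    // Invalid shape: "=" with a single literal child, so toV1 yields None.
    val bad = new Predicate("=", Array[Expression](LiteralValue(1, IntegerType)))

    assert(PredicateUtils.toV1(new And(good, bad)) == None)   // all-or-nothing
    assert(PredicateUtils.toV1(new Or(good, bad)) ==
      Some(EqualTo("a", 1)))                                  // keeps the convertible side
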
@@ -243,6 +319,9 @@
 
     val v1Filter = StringStartsWith("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("StringEndsWith") {
@@ -257,6 +336,9 @@
 
     val v1Filter = StringEndsWith("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("StringContains") {
@@ -271,6 +353,9 @@
 
     val v1Filter = StringContains("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 }
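One guard in PredicateUtils worth calling out here: STARTS_WITH, ENDS_WITH and CONTAINS only convert when the literal is a string; any other literal type returns None rather than being coerced. A sketch:

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
    import org.apache.spark.sql.connector.expressions.filter.Predicate
    import org.apache.spark.sql.internal.connector.PredicateUtils
    import org.apache.spark.sql.types.IntegerType

    // Valid binary shape, but the literal is an integer, so the
    // sameType(StringType) check rejects it and no V1 filter is produced.
    val contains = new Predicate("CONTAINS",
      Array[Expression](FieldReference(Seq("a")), LiteralValue(1, IntegerType)))
    assert(PredicateUtils.toV1(contains) == None)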