diff --git a/pom.xml b/pom.xml
index 5e29ed5792518..acebec1e2ec3e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -128,7 +128,7 @@
<hive.version.short>2.3</hive.version.short>
- <kafka.version>3.2.0</kafka.version>
+ <kafka.version>3.2.1</kafka.version>
<derby.version>10.14.2.0</derby.version>
<parquet.version>1.12.3</parquet.version>
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 465541abdaa9e..dccf48b6991b0 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -9012,9 +9012,9 @@ def add_prefix(self, prefix: str) -> "DataFrame":
2 3 5
3 4 6
"""
- assert isinstance(prefix, str)
+ f = partial("{prefix}{}".format, prefix=prefix)
return self._apply_series_op(
- lambda psser: psser.rename(tuple([prefix + i for i in psser._column_label]))
+ lambda psser: psser.rename(tuple([f(i) for i in psser._column_label]))
)
def add_suffix(self, suffix: str) -> "DataFrame":
@@ -9057,9 +9057,9 @@ def add_suffix(self, suffix: str) -> "DataFrame":
2 3 5
3 4 6
"""
- assert isinstance(suffix, str)
+ f = partial("{}{suffix}".format, suffix=suffix)
return self._apply_series_op(
- lambda psser: psser.rename(tuple([i + suffix for i in psser._column_label]))
+ lambda psser: psser.rename(tuple([f(i) for i in psser._column_label]))
)
# TODO: include, and exclude should be implemented.
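
A minimal, self-contained sketch of the renaming helper introduced above (frame.py already imports functools.partial; the names add_prefix_label and add_suffix_label are illustrative): binding the prefix or suffix as a format keyword lets str.format stringify non-string values such as floats and booleans, which the removed isinstance asserts used to reject.

from functools import partial

# Bind the affix once; str.format stringifies it on each call.
add_prefix_label = partial("{prefix}{}".format, prefix=1.1)
add_suffix_label = partial("{}{suffix}".format, suffix=True)

assert add_prefix_label("A") == "1.1A"
assert add_suffix_label("A") == "ATrue"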
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index ff4c7fcc8f140..a5540d8acc5d4 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -3170,11 +3170,10 @@ def add_prefix(self, prefix: str) -> "Series":
item_3 4
dtype: int64
"""
- assert isinstance(prefix, str)
internal = self._internal.resolved_copy
sdf = internal.spark_frame.select(
[
- F.concat(SF.lit(prefix), index_spark_column).alias(index_spark_column_name)
+ F.concat(SF.lit(str(prefix)), index_spark_column).alias(index_spark_column_name)
for index_spark_column, index_spark_column_name in zip(
internal.index_spark_columns, internal.index_spark_column_names
)
@@ -3225,11 +3224,10 @@ def add_suffix(self, suffix: str) -> "Series":
3_item 4
dtype: int64
"""
- assert isinstance(suffix, str)
internal = self._internal.resolved_copy
sdf = internal.spark_frame.select(
[
- F.concat(index_spark_column, SF.lit(suffix)).alias(index_spark_column_name)
+ F.concat(index_spark_column, SF.lit(str(suffix))).alias(index_spark_column_name)
for index_spark_column, index_spark_column_name in zip(
internal.index_spark_columns, internal.index_spark_column_names
)
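
The Series variant builds the new index with Spark's concat, so the affix has to be stringified on the Python side first; that is what the added str() around the lit() argument provides. A hedged sketch against plain pyspark.sql (SF.lit in the patch is the pandas-on-Spark wrapper around the same lit), assuming a local SparkSession:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
sdf = spark.createDataFrame([("item_1",), ("item_2",)], ["idx"])
# str() converts on the Python side, matching pandas semantics:
# str(True) == "True", whereas letting Spark cast a boolean literal
# to string would render "true".
sdf.select(F.concat(F.lit(str(True)), F.col("idx")).alias("idx")).show()
# +----------+
# |       idx|
# +----------+
# |Trueitem_1|
# |Trueitem_2|
# +----------+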
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 1361c44404a3d..a9b8ba9c96f65 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -2701,6 +2701,8 @@ def test_add_prefix(self):
pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}, index=np.random.rand(4))
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.add_prefix("col_"), psdf.add_prefix("col_"))
+ self.assert_eq(pdf.add_prefix(1.1), psdf.add_prefix(1.1))
+ self.assert_eq(pdf.add_prefix(True), psdf.add_prefix(True))
columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")])
pdf.columns = columns
@@ -2711,6 +2713,8 @@ def test_add_suffix(self):
pdf = pd.DataFrame({"A": [1, 2, 3, 4], "B": [3, 4, 5, 6]}, index=np.random.rand(4))
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.add_suffix("first_series"), psdf.add_suffix("first_series"))
+ self.assert_eq(pdf.add_suffix(1.1), psdf.add_suffix(1.1))
+ self.assert_eq(pdf.add_suffix(True), psdf.add_suffix(True))
columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B")])
pdf.columns = columns
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index 144df0f986a70..6ef7c805507b1 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -1293,6 +1293,8 @@ def test_add_prefix(self):
pser = pd.Series([1, 2, 3, 4], name="0")
psser = ps.from_pandas(pser)
self.assert_eq(pser.add_prefix("item_"), psser.add_prefix("item_"))
+ self.assert_eq(pser.add_prefix(1.1), psser.add_prefix(1.1))
+ self.assert_eq(pser.add_prefix(False), psser.add_prefix(False))
pser = pd.Series(
[1, 2, 3],
@@ -1306,6 +1308,8 @@ def test_add_suffix(self):
pser = pd.Series([1, 2, 3, 4], name="0")
psser = ps.from_pandas(pser)
self.assert_eq(pser.add_suffix("_item"), psser.add_suffix("_item"))
+ self.assert_eq(pser.add_suffix(1.1), psser.add_suffix(1.1))
+ self.assert_eq(pser.add_suffix(False), psser.add_suffix(False))
pser = pd.Series(
[1, 2, 3],
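
A hedged end-to-end sketch of what the new test cases pin down (assumes a working pyspark.pandas environment; output shown approximately): with the asserts gone, a float or boolean affix is formatted into the index labels just as pandas does.

import pyspark.pandas as ps

psser = ps.Series([1, 2, 3, 4], name="0")
print(psser.add_prefix(1.1))
# 1.10    1
# 1.11    2
# 1.12    3
# 1.13    4
# Name: 0, dtype: int64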
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
index ace6b30d4ccec..263edd82197bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
@@ -19,14 +19,25 @@ package org.apache.spark.sql.internal.connector
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference}
-import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.sources.{Filter, In}
+import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
+import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
+import org.apache.spark.sql.types.StringType
private[sql] object PredicateUtils {
def toV1(predicate: Predicate): Option[Filter] = {
+
+ def isValidBinaryPredicate: Boolean =
+ predicate.children().length == 2 &&
+ predicate.children()(0).isInstanceOf[NamedReference] &&
+ predicate.children()(1).isInstanceOf[LiteralValue[_]]
+
predicate.name() match {
- // TODO: add conversion for other V2 Predicate
case "IN" if predicate.children()(0).isInstanceOf[NamedReference] =>
val attribute = predicate.children()(0).toString
val values = predicate.children().drop(1)
@@ -43,6 +54,81 @@ private[sql] object PredicateUtils {
Some(In(attribute, Array.empty[Any]))
}
+ case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate =>
+ val attribute = predicate.children()(0).toString
+ val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+ val v1Value = CatalystTypeConverters.convertToScala(value.value, value.dataType)
+ val v1Filter = predicate.name() match {
+ case "=" => EqualTo(attribute, v1Value)
+ case "<=>" => EqualNullSafe(attribute, v1Value)
+ case ">" => GreaterThan(attribute, v1Value)
+ case ">=" => GreaterThanOrEqual(attribute, v1Value)
+ case "<" => LessThan(attribute, v1Value)
+ case "<=" => LessThanOrEqual(attribute, v1Value)
+ }
+ Some(v1Filter)
+
+ case "IS_NULL" | "IS_NOT_NULL" if predicate.children().length == 1 &&
+ predicate.children()(0).isInstanceOf[NamedReference] =>
+ val attribute = predicate.children()(0).toString
+ val v1Filter = predicate.name() match {
+ case "IS_NULL" => IsNull(attribute)
+ case "IS_NOT_NULL" => IsNotNull(attribute)
+ }
+ Some(v1Filter)
+
+ case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate =>
+ val attribute = predicate.children()(0).toString
+ val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+ if (!value.dataType.sameType(StringType)) return None
+ val v1Value = value.value.toString
+ val v1Filter = predicate.name() match {
+ case "STARTS_WITH" =>
+ StringStartsWith(attribute, v1Value)
+ case "ENDS_WITH" =>
+ StringEndsWith(attribute, v1Value)
+ case "CONTAINS" =>
+ StringContains(attribute, v1Value)
+ }
+ Some(v1Filter)
+
+ case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty =>
+ val v1Filter = predicate.name() match {
+ case "ALWAYS_TRUE" => AlwaysTrue()
+ case "ALWAYS_FALSE" => AlwaysFalse()
+ }
+ Some(v1Filter)
+
+ case "AND" =>
+ val and = predicate.asInstanceOf[V2And]
+ val left = toV1(and.left())
+ val right = toV1(and.right())
+ if (left.nonEmpty && right.nonEmpty) {
+ Some(And(left.get, right.get))
+ } else {
+ None
+ }
+
+ case "OR" =>
+ val or = predicate.asInstanceOf[V2Or]
+ val left = toV1(or.left())
+ val right = toV1(or.right())
+ if (left.nonEmpty && right.nonEmpty) {
+ Some(Or(left.get, right.get))
+ } else if (left.nonEmpty) {
+ left
+ } else {
+ right
+ }
+
+ case "NOT" =>
+ val child = toV1(predicate.asInstanceOf[V2Not].child())
+ if (child.nonEmpty) {
+ Some(Not(child.get))
+ } else {
+ None
+ }
+
case _ => None
}
}
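
A minimal Scala sketch of the new conversions (PredicateUtils is private[sql], so a real caller must live under the org.apache.spark.sql package tree; ToV1Sketch is a hypothetical object name):

package org.apache.spark.sql.internal.connector

import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.sources.GreaterThan
import org.apache.spark.sql.types.IntegerType

object ToV1Sketch {
  def main(args: Array[String]): Unit = {
    // a > 1: column reference on the left, literal on the right, so
    // isValidBinaryPredicate holds and the v1 GreaterThan filter comes back.
    val gt = new Predicate(">",
      Array[Expression](FieldReference("a"), LiteralValue(1, IntegerType)))
    assert(PredicateUtils.toV1(gt).contains(GreaterThan("a", 1)))

    // 1 > 1: no column reference, so the conversion falls through to None.
    val invalid = new Predicate(">",
      Array[Expression](LiteralValue(1, IntegerType), LiteralValue(1, IntegerType)))
    assert(PredicateUtils.toV1(invalid).isEmpty)
  }
}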
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
index 2df8b8e56c44b..de556c50f5d4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter._
import org.apache.spark.sql.execution.datasources.v2.V2PredicateSuite.ref
+import org.apache.spark.sql.internal.connector.PredicateUtils
import org.apache.spark.sql.sources.{AlwaysFalse => V1AlwaysFalse, AlwaysTrue => V1AlwaysTrue, And => V1And, EqualNullSafe, EqualTo, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not => V1Not, Or => V1Or, StringContains, StringEndsWith, StringStartsWith}
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -34,6 +35,9 @@ class V2PredicateSuite extends SparkFunSuite {
assert(predicate1.describe.equals("a.B = 1"))
val v1Filter1 = EqualTo(ref("a", "B").describe(), 1)
assert(v1Filter1.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter1)
+ assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
val predicate2 =
new Predicate("=", Array[Expression](ref("a", "b.c"), LiteralValue(1, IntegerType)))
@@ -41,6 +45,9 @@ class V2PredicateSuite extends SparkFunSuite {
assert(predicate2.describe.equals("a.`b.c` = 1"))
val v1Filter2 = EqualTo(ref("a", "b.c").describe(), 1)
assert(v1Filter2.toV2 == predicate2)
+ assert(PredicateUtils.toV1(predicate2).get == v1Filter2)
+ assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2)
+ assert(PredicateUtils.toV1(predicate2).get.toV2 == predicate2)
val predicate3 =
new Predicate("=", Array[Expression](ref("`a`.b", "c"), LiteralValue(1, IntegerType)))
@@ -48,6 +55,9 @@ class V2PredicateSuite extends SparkFunSuite {
assert(predicate3.describe.equals("```a``.b`.c = 1"))
val v1Filter3 = EqualTo(ref("`a`.b", "c").describe(), 1)
assert(v1Filter3.toV2 == predicate3)
+ assert(PredicateUtils.toV1(predicate3).get == v1Filter3)
+ assert(PredicateUtils.toV1(v1Filter3.toV2).get == v1Filter3)
+ assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3)
}
test("AlwaysTrue") {
@@ -59,6 +69,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = V1AlwaysTrue
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("AlwaysFalse") {
@@ -70,6 +83,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = V1AlwaysFalse
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("EqualTo") {
@@ -81,6 +97,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = EqualTo("a", 1)
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("EqualNullSafe") {
@@ -92,6 +111,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = EqualNullSafe("a", 1)
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("LessThan") {
@@ -103,6 +125,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = LessThan("a", 1)
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("LessThanOrEqual") {
@@ -114,6 +139,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = LessThanOrEqual("a", 1)
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("GreatThan") {
@@ -125,6 +153,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = GreaterThan("a", 1)
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("GreatThanOrEqual") {
@@ -136,6 +167,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = GreaterThanOrEqual("a", 1)
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("In") {
@@ -161,9 +195,15 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter1 = In("a", Array(1, 2, 3, 4))
assert(v1Filter1.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter1)
+ assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
val v1Filter2 = In("a", values.map(_.value()))
assert(v1Filter2.toV2 == predicate3)
+ assert(PredicateUtils.toV1(predicate3).get == v1Filter2)
+ assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2)
+ assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3)
}
test("IsNull") {
@@ -175,6 +215,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = IsNull("a")
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("IsNotNull") {
@@ -186,6 +229,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = IsNotNull("a")
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("Not") {
@@ -199,6 +245,14 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = V1Not(LessThan("a", 1))
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+ val predicate3 = new Not(
+ new Predicate("=", Array[Expression](LiteralValue(1, IntegerType),
+ LiteralValue(1, IntegerType))))
+ assert(PredicateUtils.toV1(predicate3) == None)
}
test("And") {
@@ -214,6 +268,15 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = V1And(EqualTo("a", 1), EqualTo("b", 1))
assert(v1Filter.toV2 == predicate1)
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+ val predicate3 = new And(
+ new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
+ new Predicate("=", Array[Expression](LiteralValue(1, IntegerType),
+ LiteralValue(1, IntegerType))))
+ assert(PredicateUtils.toV1(predicate3) == None)
}
test("Or") {
@@ -229,6 +292,19 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = V1Or(EqualTo("a", 1), EqualTo("b", 1))
assert(v1Filter.toV2.equals(predicate1))
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+ val predicate3 = new Or(
+ new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
+ new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))))
+ assert(PredicateUtils.toV1(predicate3) == None)
+
+ val predicate4 = new Or(
+ new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))),
+ new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))))
+ assert(PredicateUtils.toV1(predicate4) == None)
}
test("StringStartsWith") {
@@ -243,6 +319,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = StringStartsWith("a", "str")
assert(v1Filter.toV2.equals(predicate1))
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("StringEndsWith") {
@@ -257,6 +336,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = StringEndsWith("a", "str")
assert(v1Filter.toV2.equals(predicate1))
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
test("StringContains") {
@@ -271,6 +353,9 @@ class V2PredicateSuite extends SparkFunSuite {
val v1Filter = StringContains("a", "str")
assert(v1Filter.toV2.equals(predicate1))
+ assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+ assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+ assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
}
}