diff --git a/pom.xml b/pom.xml index 5e29ed5792518..acebec1e2ec3e 100644 --- a/pom.xml +++ b/pom.xml @@ -128,7 +128,7 @@ 2.3 - 3.2.0 + 3.2.1 10.14.2.0 1.12.3 diff --git a/python/pyspark/pandas/base.py b/python/pyspark/pandas/base.py index 3430f5efa93ee..bf7149e6b2312 100644 --- a/python/pyspark/pandas/base.py +++ b/python/pyspark/pandas/base.py @@ -1179,6 +1179,9 @@ def _shift( if not isinstance(periods, int): raise TypeError("periods should be an int; however, got [%s]" % type(periods).__name__) + if periods == 0: + return self.copy() + col = self.spark.column window = ( Window.partitionBy(*part_cols) diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py index 1361c44404a3d..add93faba0c8f 100644 --- a/python/pyspark/pandas/tests/test_dataframe.py +++ b/python/pyspark/pandas/tests/test_dataframe.py @@ -4249,6 +4249,7 @@ def test_shift(self): psdf.columns = columns self.assert_eq(pdf.shift(3), psdf.shift(3)) self.assert_eq(pdf.shift().shift(-1), psdf.shift().shift(-1)) + self.assert_eq(pdf.shift(0), psdf.shift(0)) def test_diff(self): pdf = pd.DataFrame( diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py index 144df0f986a70..6bc07def712cc 100644 --- a/python/pyspark/pandas/tests/test_series.py +++ b/python/pyspark/pandas/tests/test_series.py @@ -1549,6 +1549,8 @@ def test_shift(self): with self.assertRaisesRegex(TypeError, "periods should be an int; however"): psser.shift(periods=1.5) + self.assert_eq(psser.shift(periods=0), pser.shift(periods=0)) + def test_diff(self): pser = pd.Series([10, 20, 15, 30, 45], name="x") psser = ps.Series(pser) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index ace6b30d4ccec..263edd82197bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -19,14 +19,25 @@ package org.apache.spark.sql.internal.connector import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference} -import org.apache.spark.sql.connector.expressions.filter.Predicate -import org.apache.spark.sql.sources.{Filter, In} +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} +import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith} +import org.apache.spark.sql.types.StringType private[sql] object PredicateUtils { def toV1(predicate: Predicate): Option[Filter] = { + + def isValidBinaryPredicate(): Boolean = { + if (predicate.children().length == 2 && + predicate.children()(0).isInstanceOf[NamedReference] && + predicate.children()(1).isInstanceOf[LiteralValue[_]]) { + true + } else { + false + } + } + predicate.name() match { - // TODO: add conversion for other V2 Predicate case "IN" if predicate.children()(0).isInstanceOf[NamedReference] => val attribute = predicate.children()(0).toString val values = predicate.children().drop(1) @@ -43,6 +54,81 @@ private[sql] object PredicateUtils { Some(In(attribute, Array.empty[Any])) } + case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate => + val attribute = predicate.children()(0).toString + val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] + val v1Value = CatalystTypeConverters.convertToScala(value.value, value.dataType) + val v1Filter = predicate.name() match { + case "=" => EqualTo(attribute, v1Value) + case "<=>" => EqualNullSafe(attribute, v1Value) + case ">" => GreaterThan(attribute, v1Value) + case ">=" => GreaterThanOrEqual(attribute, v1Value) + case "<" => LessThan(attribute, v1Value) + case "<=" => LessThanOrEqual(attribute, v1Value) + } + Some(v1Filter) + + case "IS_NULL" | "IS_NOT_NULL" if predicate.children().length == 1 && + predicate.children()(0).isInstanceOf[NamedReference] => + val attribute = predicate.children()(0).toString + val v1Filter = predicate.name() match { + case "IS_NULL" => IsNull(attribute) + case "IS_NOT_NULL" => IsNotNull(attribute) + } + Some(v1Filter) + + case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate => + val attribute = predicate.children()(0).toString + val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] + if (!value.dataType.sameType(StringType)) return None + val v1Value = value.value.toString + val v1Filter = predicate.name() match { + case "STARTS_WITH" => + StringStartsWith(attribute, v1Value) + case "ENDS_WITH" => + StringEndsWith(attribute, v1Value) + case "CONTAINS" => + StringContains(attribute, v1Value) + } + Some(v1Filter) + + case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty => + val v1Filter = predicate.name() match { + case "ALWAYS_TRUE" => AlwaysTrue() + case "ALWAYS_FALSE" => AlwaysFalse() + } + Some(v1Filter) + + case "AND" => + val and = predicate.asInstanceOf[V2And] + val left = toV1(and.left()) + val right = toV1(and.right()) + if (left.nonEmpty && right.nonEmpty) { + Some(And(left.get, right.get)) + } else { + None + } + + case "OR" => + val or = predicate.asInstanceOf[V2Or] + val left = toV1(or.left()) + val right = toV1(or.right()) + if (left.nonEmpty && right.nonEmpty) { + Some(Or(left.get, right.get)) + } else if (left.nonEmpty) { + left + } else { + right + } + + case "NOT" => + val child = toV1(predicate.asInstanceOf[V2Not].child()) + if (child.nonEmpty) { + Some(Not(child.get)) + } else { + None + } + case _ => None } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala index 2df8b8e56c44b..de556c50f5d4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, LiteralValue} import org.apache.spark.sql.connector.expressions.filter._ import org.apache.spark.sql.execution.datasources.v2.V2PredicateSuite.ref +import org.apache.spark.sql.internal.connector.PredicateUtils import org.apache.spark.sql.sources.{AlwaysFalse => V1AlwaysFalse, AlwaysTrue => V1AlwaysTrue, And => V1And, EqualNullSafe, EqualTo, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not => V1Not, Or => V1Or, StringContains, StringEndsWith, StringStartsWith} import org.apache.spark.sql.types.{IntegerType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -34,6 +35,9 @@ class V2PredicateSuite extends SparkFunSuite { assert(predicate1.describe.equals("a.B = 1")) val v1Filter1 = EqualTo(ref("a", "B").describe(), 1) assert(v1Filter1.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter1) + assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) val predicate2 = new Predicate("=", Array[Expression](ref("a", "b.c"), LiteralValue(1, IntegerType))) @@ -41,6 +45,9 @@ class V2PredicateSuite extends SparkFunSuite { assert(predicate2.describe.equals("a.`b.c` = 1")) val v1Filter2 = EqualTo(ref("a", "b.c").describe(), 1) assert(v1Filter2.toV2 == predicate2) + assert(PredicateUtils.toV1(predicate2).get == v1Filter2) + assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2) + assert(PredicateUtils.toV1(predicate2).get.toV2 == predicate2) val predicate3 = new Predicate("=", Array[Expression](ref("`a`.b", "c"), LiteralValue(1, IntegerType))) @@ -48,6 +55,9 @@ class V2PredicateSuite extends SparkFunSuite { assert(predicate3.describe.equals("```a``.b`.c = 1")) val v1Filter3 = EqualTo(ref("`a`.b", "c").describe(), 1) assert(v1Filter3.toV2 == predicate3) + assert(PredicateUtils.toV1(predicate3).get == v1Filter3) + assert(PredicateUtils.toV1(v1Filter3.toV2).get == v1Filter3) + assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3) } test("AlwaysTrue") { @@ -59,6 +69,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1AlwaysTrue assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("AlwaysFalse") { @@ -70,6 +83,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1AlwaysFalse assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("EqualTo") { @@ -81,6 +97,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = EqualTo("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("EqualNullSafe") { @@ -92,6 +111,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = EqualNullSafe("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("LessThan") { @@ -103,6 +125,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = LessThan("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("LessThanOrEqual") { @@ -114,6 +139,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = LessThanOrEqual("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("GreatThan") { @@ -125,6 +153,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = GreaterThan("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("GreatThanOrEqual") { @@ -136,6 +167,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = GreaterThanOrEqual("a", 1) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("In") { @@ -161,9 +195,15 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter1 = In("a", Array(1, 2, 3, 4)) assert(v1Filter1.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter1) + assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) val v1Filter2 = In("a", values.map(_.value())) assert(v1Filter2.toV2 == predicate3) + assert(PredicateUtils.toV1(predicate3).get == v1Filter2) + assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2) + assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3) } test("IsNull") { @@ -175,6 +215,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = IsNull("a") assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("IsNotNull") { @@ -186,6 +229,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = IsNotNull("a") assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("Not") { @@ -199,6 +245,14 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1Not(LessThan("a", 1)) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) + + val predicate3 = new Not( + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType), + LiteralValue(1, IntegerType)))) + assert(PredicateUtils.toV1(predicate3) == None) } test("And") { @@ -214,6 +268,15 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1And(EqualTo("a", 1), EqualTo("b", 1)) assert(v1Filter.toV2 == predicate1) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) + + val predicate3 = new And( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType), + LiteralValue(1, IntegerType)))) + assert(PredicateUtils.toV1(predicate3) == None) } test("Or") { @@ -229,6 +292,19 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = V1Or(EqualTo("a", 1), EqualTo("b", 1)) assert(v1Filter.toV2.equals(predicate1)) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) + + val left = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + val predicate3 = new Or(left, + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType)))) + assert(PredicateUtils.toV1(predicate3) == PredicateUtils.toV1(left)) + + val predicate4 = new Or( + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](LiteralValue(1, IntegerType)))) + assert(PredicateUtils.toV1(predicate4) == None) } test("StringStartsWith") { @@ -243,6 +319,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = StringStartsWith("a", "str") assert(v1Filter.toV2.equals(predicate1)) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("StringEndsWith") { @@ -257,6 +336,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = StringEndsWith("a", "str") assert(v1Filter.toV2.equals(predicate1)) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } test("StringContains") { @@ -271,6 +353,9 @@ class V2PredicateSuite extends SparkFunSuite { val v1Filter = StringContains("a", "str") assert(v1Filter.toV2.equals(predicate1)) + assert(PredicateUtils.toV1(predicate1).get == v1Filter) + assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter) + assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1) } }