diff --git a/pom.xml b/pom.xml
index 5e29ed5792518..acebec1e2ec3e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -128,7 +128,7 @@
     2.3
-    3.2.0
+    3.2.1
     10.14.2.0
     1.12.3
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 465541abdaa9e..3953ab30c1ed5 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -7766,6 +7766,7 @@ def nsmallest(
         a  1
         d  2
         """
+        assert type(n) is int
         by_scols = self._prepare_sort_by_scols(columns)
         return self._sort(by=by_scols, ascending=True, na_position="last", keep=keep).head(n=n)
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 4377ad6a5c91e..416411f6dd3f0 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -3574,6 +3574,7 @@ def nsmallest(self, n: int = 5) -> Series:
         3  6    3
         Name: b, dtype: int64
         """
+        assert type(n) is int
         if self._psser._internal.index_level > 1:
             raise ValueError("nsmallest do not support multi-index now")
diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py
index ff4c7fcc8f140..de7f383b43a4c 100644
--- a/python/pyspark/pandas/series.py
+++ b/python/pyspark/pandas/series.py
@@ -3417,6 +3417,7 @@ def nsmallest(self, n: int = 5) -> "Series":
         2    3.0
         dtype: float64
         """
+        assert type(n) is int
         return self.sort_values(ascending=True).head(n)
 
     def nlargest(self, n: int = 5) -> "Series":
diff --git a/python/pyspark/pandas/tests/test_dataframe.py b/python/pyspark/pandas/tests/test_dataframe.py
index 1361c44404a3d..f7d055130b41c 100644
--- a/python/pyspark/pandas/tests/test_dataframe.py
+++ b/python/pyspark/pandas/tests/test_dataframe.py
@@ -1907,6 +1907,10 @@ def test_nsmallest(self):
         msg = 'keep must be either "first", "last" or "all".'
         with self.assertRaisesRegex(ValueError, msg):
             psdf.nlargest(5, columns=["c"], keep="xx")
+        with self.assertRaises(AssertionError):
+            psdf.nsmallest("test", columns=["c"], keep="last")
+        with self.assertRaises(AssertionError):
+            psdf.nsmallest(0.1, columns=["c"], keep="last")
 
     def test_xs(self):
         d = {
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index cff2ce706d8cb..8b1457aa6fc4b 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -1765,6 +1765,10 @@ def test_nsmallest(self):
         )
         with self.assertRaisesRegex(ValueError, "nsmallest do not support multi-index now"):
             psdf.set_index(["a", "b"]).groupby(["c"])["d"].nsmallest(1)
+        with self.assertRaises(AssertionError):
+            psdf.groupby(["a"])["b"].nsmallest(0.1).sort_index()
+        with self.assertRaises(AssertionError):
+            psdf.groupby(["a"])["b"].nsmallest(False).sort_index()
 
     def test_nlargest(self):
         pdf = pd.DataFrame(
diff --git a/python/pyspark/pandas/tests/test_series.py b/python/pyspark/pandas/tests/test_series.py
index 144df0f986a70..2ab4d8d46c141 100644
--- a/python/pyspark/pandas/tests/test_series.py
+++ b/python/pyspark/pandas/tests/test_series.py
@@ -877,6 +877,11 @@ def test_nsmallest(self):
         self.assert_eq(psser.nsmallest(), pser.nsmallest())
         self.assert_eq((psser + 1).nsmallest(), (pser + 1).nsmallest())
 
+        with self.assertRaises(AssertionError):
+            psser.nsmallest("String")
+        with self.assertRaises(AssertionError):
+            psser.nsmallest(0.1)
+
     def test_nlargest(self):
         sample_lst = [1, 2, 3, 4, np.nan, 6]
         pser = pd.Series(sample_lst, name="x")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
index ace6b30d4ccec..263edd82197bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
@@ -19,14 +19,25 @@ package org.apache.spark.sql.internal.connector
 
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference}
-import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.sources.{Filter, In}
+import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
+import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
+import org.apache.spark.sql.types.StringType
 
 private[sql] object PredicateUtils {
 
   def toV1(predicate: Predicate): Option[Filter] = {
+
+    def isValidBinaryPredicate(): Boolean = {
+      if (predicate.children().length == 2 &&
+        predicate.children()(0).isInstanceOf[NamedReference] &&
+        predicate.children()(1).isInstanceOf[LiteralValue[_]]) {
+        true
+      } else {
+        false
+      }
+    }
+
     predicate.name() match {
-      // TODO: add conversion for other V2 Predicate
       case "IN" if predicate.children()(0).isInstanceOf[NamedReference] =>
         val attribute = predicate.children()(0).toString
         val values = predicate.children().drop(1)
@@ -43,6 +54,81 @@ private[sql] object PredicateUtils {
           Some(In(attribute, Array.empty[Any]))
         }
 
+      case "=" | "<=>" | ">" | "<" | ">=" | "<=" if isValidBinaryPredicate =>
+        val attribute = predicate.children()(0).toString
+        val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+        val v1Value = CatalystTypeConverters.convertToScala(value.value, value.dataType)
+        val v1Filter = predicate.name() match {
+          case "=" => EqualTo(attribute, v1Value)
+          case "<=>" => EqualNullSafe(attribute, v1Value)
+          case ">" => GreaterThan(attribute, v1Value)
+          case ">=" => GreaterThanOrEqual(attribute, v1Value)
+          case "<" => LessThan(attribute, v1Value)
+          case "<=" => LessThanOrEqual(attribute, v1Value)
+        }
+        Some(v1Filter)
+
+      case "IS_NULL" | "IS_NOT_NULL" if predicate.children().length == 1 &&
+        predicate.children()(0).isInstanceOf[NamedReference] =>
+        val attribute = predicate.children()(0).toString
+        val v1Filter = predicate.name() match {
+          case "IS_NULL" => IsNull(attribute)
+          case "IS_NOT_NULL" => IsNotNull(attribute)
+        }
+        Some(v1Filter)
+
+      case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate =>
+        val attribute = predicate.children()(0).toString
+        val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
+        if (!value.dataType.sameType(StringType)) return None
+        val v1Value = value.value.toString
+        val v1Filter = predicate.name() match {
+          case "STARTS_WITH" =>
+            StringStartsWith(attribute, v1Value)
+          case "ENDS_WITH" =>
+            StringEndsWith(attribute, v1Value)
+          case "CONTAINS" =>
+            StringContains(attribute, v1Value)
+        }
+        Some(v1Filter)
+
+      case "ALWAYS_TRUE" | "ALWAYS_FALSE" if predicate.children().isEmpty =>
+        val v1Filter = predicate.name() match {
+          case "ALWAYS_TRUE" => AlwaysTrue()
+          case "ALWAYS_FALSE" => AlwaysFalse()
+        }
+        Some(v1Filter)
+
+      case "AND" =>
+        val and = predicate.asInstanceOf[V2And]
+        val left = toV1(and.left())
+        val right = toV1(and.right())
+        if (left.nonEmpty && right.nonEmpty) {
+          Some(And(left.get, right.get))
+        } else {
+          None
+        }
+
+      case "OR" =>
+        val or = predicate.asInstanceOf[V2Or]
+        val left = toV1(or.left())
+        val right = toV1(or.right())
+        if (left.nonEmpty && right.nonEmpty) {
+          Some(Or(left.get, right.get))
+        } else if (left.nonEmpty) {
+          left
+        } else {
+          right
+        }
+
+      case "NOT" =>
+        val child = toV1(predicate.asInstanceOf[V2Not].child())
+        if (child.nonEmpty) {
+          Some(Not(child.get))
+        } else {
+          None
+        }
+
       case _ => None
     }
   }
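For orientation before the test suite changes below, here is a minimal sketch of the new `toV1` conversion in action, built from the same constructors the tests use. Note that `PredicateUtils` is `private[sql]`, so a snippet like this only compiles somewhere under Spark's `org.apache.spark.sql` package tree; the package and object names below are hypothetical, not part of this change.

```scala
// Hypothetical demo location; it must sit under org.apache.spark.sql
// because PredicateUtils is private[sql].
package org.apache.spark.sql.example

import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.internal.connector.PredicateUtils
import org.apache.spark.sql.types.IntegerType

object ToV1Sketch {
  def main(args: Array[String]): Unit = {
    // Column-vs-literal comparison: satisfies the isValidBinaryPredicate shape check.
    val convertible = new Predicate("=",
      Array[Expression](FieldReference(Seq("a")), LiteralValue(1, IntegerType)))
    println(PredicateUtils.toV1(convertible)) // Some(EqualTo(a,1))

    // Literal-vs-literal children fail the shape check, so no V1 filter is produced.
    val unconvertible = new Predicate("=",
      Array[Expression](LiteralValue(1, IntegerType), LiteralValue(1, IntegerType)))
    println(PredicateUtils.toV1(unconvertible)) // None
  }
}
```

For every convertible predicate the round trip `toV1(p).get.toV2 == p` also holds, which is exactly what the updated suite below asserts for each filter type.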
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
index 2df8b8e56c44b..de556c50f5d4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter._
 import org.apache.spark.sql.execution.datasources.v2.V2PredicateSuite.ref
+import org.apache.spark.sql.internal.connector.PredicateUtils
 import org.apache.spark.sql.sources.{AlwaysFalse => V1AlwaysFalse, AlwaysTrue => V1AlwaysTrue, And => V1And, EqualNullSafe, EqualTo, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not => V1Not, Or => V1Or, StringContains, StringEndsWith, StringStartsWith}
 import org.apache.spark.sql.types.{IntegerType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
@@ -34,6 +35,9 @@ class V2PredicateSuite extends SparkFunSuite {
     assert(predicate1.describe.equals("a.B = 1"))
     val v1Filter1 = EqualTo(ref("a", "B").describe(), 1)
     assert(v1Filter1.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter1)
+    assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
 
     val predicate2 =
       new Predicate("=", Array[Expression](ref("a", "b.c"), LiteralValue(1, IntegerType)))
@@ -41,6 +45,9 @@
     assert(predicate2.describe.equals("a.`b.c` = 1"))
     val v1Filter2 = EqualTo(ref("a", "b.c").describe(), 1)
     assert(v1Filter2.toV2 == predicate2)
+    assert(PredicateUtils.toV1(predicate2).get == v1Filter2)
+    assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2)
+    assert(PredicateUtils.toV1(predicate2).get.toV2 == predicate2)
 
     val predicate3 =
       new Predicate("=", Array[Expression](ref("`a`.b", "c"), LiteralValue(1, IntegerType)))
@@ -48,6 +55,9 @@
     assert(predicate3.describe.equals("```a``.b`.c = 1"))
     val v1Filter3 = EqualTo(ref("`a`.b", "c").describe(), 1)
     assert(v1Filter3.toV2 == predicate3)
+    assert(PredicateUtils.toV1(predicate3).get == v1Filter3)
+    assert(PredicateUtils.toV1(v1Filter3.toV2).get == v1Filter3)
+    assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3)
   }
 
   test("AlwaysTrue") {
@@ -59,6 +69,9 @@
 
     val v1Filter = V1AlwaysTrue
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("AlwaysFalse") {
@@ -70,6 +83,9 @@
 
     val v1Filter = V1AlwaysFalse
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("EqualTo") {
@@ -81,6 +97,9 @@
 
     val v1Filter = EqualTo("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("EqualNullSafe") {
@@ -92,6 +111,9 @@
 
     val v1Filter = EqualNullSafe("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("LessThan") {
@@ -103,6 +125,9 @@
 
     val v1Filter = LessThan("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("LessThanOrEqual") {
@@ -114,6 +139,9 @@
 
     val v1Filter = LessThanOrEqual("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("GreatThan") {
@@ -125,6 +153,9 @@
 
     val v1Filter = GreaterThan("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("GreatThanOrEqual") {
@@ -136,6 +167,9 @@
 
     val v1Filter = GreaterThanOrEqual("a", 1)
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("In") {
@@ -161,9 +195,15 @@
 
     val v1Filter1 = In("a", Array(1, 2, 3, 4))
     assert(v1Filter1.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter1)
+    assert(PredicateUtils.toV1(v1Filter1.toV2).get == v1Filter1)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
 
     val v1Filter2 = In("a", values.map(_.value()))
     assert(v1Filter2.toV2 == predicate3)
+    assert(PredicateUtils.toV1(predicate3).get == v1Filter2)
+    assert(PredicateUtils.toV1(v1Filter2.toV2).get == v1Filter2)
+    assert(PredicateUtils.toV1(predicate3).get.toV2 == predicate3)
   }
 
   test("IsNull") {
@@ -175,6 +215,9 @@
 
     val v1Filter = IsNull("a")
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("IsNotNull") {
@@ -186,6 +229,9 @@
 
     val v1Filter = IsNotNull("a")
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("Not") {
@@ -199,6 +245,14 @@
 
     val v1Filter = V1Not(LessThan("a", 1))
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+    val predicate3 = new Not(
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType),
+        LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate3) == None)
   }
 
   test("And") {
@@ -214,6 +268,15 @@
 
     val v1Filter = V1And(EqualTo("a", 1), EqualTo("b", 1))
     assert(v1Filter.toV2 == predicate1)
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+    val predicate3 = new And(
+      new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType),
+        LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate3) == None)
   }
 
   test("Or") {
@@ -229,6 +292,19 @@
 
     val v1Filter = V1Or(EqualTo("a", 1), EqualTo("b", 1))
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
+
+    val left = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))
+    val predicate3 = new Or(left,
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate3) == PredicateUtils.toV1(left))
+
+    val predicate4 = new Or(
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))),
+      new Predicate("=", Array[Expression](LiteralValue(1, IntegerType))))
+    assert(PredicateUtils.toV1(predicate4) == None)
   }
 
   test("StringStartsWith") {
@@ -243,6 +319,9 @@
 
     val v1Filter = StringStartsWith("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 
   test("StringEndsWith") {
@@ -257,6 +336,9 @@
 
     val v1Filter = StringEndsWith("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
  }
 
   test("StringContains") {
@@ -271,6 +353,9 @@
 
     val v1Filter = StringContains("a", "str")
     assert(v1Filter.toV2.equals(predicate1))
+    assert(PredicateUtils.toV1(predicate1).get == v1Filter)
+    assert(PredicateUtils.toV1(v1Filter.toV2).get == v1Filter)
+    assert(PredicateUtils.toV1(predicate1).get.toV2 == predicate1)
   }
 }
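A subtlety the `Not`/`And`/`Or` tests above pin down: when only one side of an `OR` converts, `toV1` falls back to that side alone, while `AND` and `NOT` return `None` as soon as any child fails to convert. A minimal sketch of that asymmetry, under the same `private[sql]` access caveat and hypothetical naming as the earlier snippet:

```scala
package org.apache.spark.sql.example // hypothetical, see the access caveat above

import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{And, Not, Or, Predicate}
import org.apache.spark.sql.internal.connector.PredicateUtils
import org.apache.spark.sql.types.IntegerType

object CompositeToV1Sketch {
  def main(args: Array[String]): Unit = {
    // One convertible side (column = literal) and one unconvertible side (literal = literal).
    val good = new Predicate("=",
      Array[Expression](FieldReference(Seq("a")), LiteralValue(1, IntegerType)))
    val bad = new Predicate("=",
      Array[Expression](LiteralValue(1, IntegerType), LiteralValue(1, IntegerType)))

    println(PredicateUtils.toV1(new Or(good, bad)))  // Some(EqualTo(a,1)): keeps the convertible side
    println(PredicateUtils.toV1(new And(good, bad))) // None: AND requires both sides to convert
    println(PredicateUtils.toV1(new Not(bad)))       // None: NOT requires its child to convert
  }
}
```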