diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0ee30e1b0800..2ba34647dbca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3459,6 +3459,16 @@ object functions { ArrayFilter(column.expr, createLambda(f)) } + /** + * Returns an array of the elements for which `f(element, index)` holds; index is 0-based. + * + * @group collection_funcs + * @since 3.0.0 + */ + def filter(column: Column, f: (Column, Column) => Column): Column = withExpr { + ArrayFilter(column.expr, createLambda(f)) + } + /** * Applies a binary operator to an initial state and all elements in the array, * and reduces this to a single state. The final state is converted into the final result diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java index a5f11d57f3ce..e240326bee63 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java @@ -110,6 +110,15 @@ public void testFilter() { null ) ); + checkAnswer( + arrDf.select(filter(col("x"), (x, i) -> x.plus(i).equalTo(10))), + toRows( + makeArray(9, 8, 7), + makeArray(7), + JavaTestUtils.makeArray(), + null + ) + ); } @Test diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 5526279b767b..06484908f5e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2305,6 +2305,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(Seq("b", "c")), Row(Seq.empty), Row(null))) + 
checkAnswer(df.select(filter(col("s"), (x, i) => i % 2 === 0)), + Seq( + Row(Seq("c", "b")), + Row(Seq("b", "c")), + Row(Seq.empty), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen