diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6b8127bab1cb..a8d6964b3b83 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3385,6 +3385,167 @@ object functions {
     ArrayExcept(col1.expr, col2.expr)
   }
 
+  private def createLambda(f: Column => Column) = {
+    val x = UnresolvedNamedLambdaVariable(Seq("x"))
+    val function = f(Column(x)).expr
+    LambdaFunction(function, Seq(x))
+  }
+
+  private def createLambda(f: (Column, Column) => Column) = {
+    val x = UnresolvedNamedLambdaVariable(Seq("x"))
+    val y = UnresolvedNamedLambdaVariable(Seq("y"))
+    val function = f(Column(x), Column(y)).expr
+    LambdaFunction(function, Seq(x, y))
+  }
+
+  private def createLambda(f: (Column, Column, Column) => Column) = {
+    val x = UnresolvedNamedLambdaVariable(Seq("x"))
+    val y = UnresolvedNamedLambdaVariable(Seq("y"))
+    val z = UnresolvedNamedLambdaVariable(Seq("z"))
+    val function = f(Column(x), Column(y), Column(z)).expr
+    LambdaFunction(function, Seq(x, y, z))
+  }
+
+  /**
+   * Returns an array of elements after applying a transformation to each element
+   * in the input array.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def transform(column: Column, f: Column => Column): Column = withExpr {
+    ArrayTransform(column.expr, createLambda(f))
+  }
+
+  /**
+   * Returns an array of elements after applying a transformation to each element
+   * in the input array. The transformation function takes the element and its
+   * index (0-based) as arguments.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def transform(column: Column, f: (Column, Column) => Column): Column = withExpr {
+    ArrayTransform(column.expr, createLambda(f))
+  }
+
+  /**
+   * Returns whether a predicate holds for one or more elements in the array.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def exists(column: Column, f: Column => Column): Column = withExpr {
+    ArrayExists(column.expr, createLambda(f))
+  }
+
+  /**
+   * Returns whether a predicate holds for every element in the array.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def forall(column: Column, f: Column => Column): Column = withExpr {
+    ArrayForAll(column.expr, createLambda(f))
+  }
+
+  /**
+   * Returns an array of elements for which a predicate holds in a given array.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def filter(column: Column, f: Column => Column): Column = withExpr {
+    ArrayFilter(column.expr, createLambda(f))
+  }
+
+  /**
+   * Applies a binary operator to an initial state and all elements in the array,
+   * and reduces this to a single state. The final state is converted into the final result
+   * by applying a finish function.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def aggregate(
+      expr: Column,
+      zero: Column,
+      merge: (Column, Column) => Column,
+      finish: Column => Column): Column = withExpr {
+    ArrayAggregate(
+      expr.expr,
+      zero.expr,
+      createLambda(merge),
+      createLambda(finish)
+    )
+  }
+
+  /**
+   * Applies a binary operator to an initial state and all elements in the array,
+   * and reduces this to a single state.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def aggregate(expr: Column, zero: Column, merge: (Column, Column) => Column): Column =
+    aggregate(expr, zero, merge, c => c)
+
+  /**
+   * Merge two given arrays, element-wise, into a single array using a function.
+   * If one array is shorter, nulls are appended at the end to match the length of the longer
+   * array, before applying the function.
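+   *
+   * A usage sketch (hypothetical `df` with two integer array columns `a1` and `a2`):
+   * {{{
+   *   df.select(zip_with(col("a1"), col("a2"), (x, y) => x + y))
+   * }}}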
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def zip_with(left: Column, right: Column, f: (Column, Column) => Column): Column = withExpr {
+    ZipWith(left.expr, right.expr, createLambda(f))
+  }
+
+  /**
+   * Applies a function to every key-value pair in a map and returns
+   * a map with the results of those applications as the new keys for the pairs.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def transform_keys(expr: Column, f: (Column, Column) => Column): Column = withExpr {
+    TransformKeys(expr.expr, createLambda(f))
+  }
+
+  /**
+   * Applies a function to every key-value pair in a map and returns
+   * a map with the results of those applications as the new values for the pairs.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def transform_values(expr: Column, f: (Column, Column) => Column): Column = withExpr {
+    TransformValues(expr.expr, createLambda(f))
+  }
+
+  /**
+   * Returns a map whose key-value pairs satisfy a predicate.
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def map_filter(expr: Column, f: (Column, Column) => Column): Column = withExpr {
+    MapFilter(expr.expr, createLambda(f))
+  }
+
+  /**
+   * Merge two given maps, key-wise, into a single map using a function.
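+   *
+   * A usage sketch (hypothetical `df` with map columns `m1` and `m2` sharing a key type):
+   * {{{
+   *   df.select(map_zip_with(col("m1"), col("m2"), (k, v1, v2) => coalesce(v1, v2)))
+   * }}}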
+   *
+   * @group collection_funcs
+   * @since 3.0.0
+   */
+  def map_zip_with(
+      left: Column,
+      right: Column,
+      f: (Column, Column, Column) => Column): Column = withExpr {
+    MapZipWith(left.expr, right.expr, createLambda(f))
+  }
+
   /**
    * Creates a new row for each element in the given array or map column.
    * Uses the default column name `col` for elements in the array and
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
new file mode 100644
index 000000000000..a5f11d57f3ce
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaHigherOrderFunctionsSuite.java
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import java.util.HashMap;
+import java.util.List;
+
+import static scala.collection.JavaConverters.mapAsScalaMap;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.types.DataTypes.*;
+import static org.apache.spark.sql.functions.*;
+import org.apache.spark.sql.test.TestSparkSession;
+import static test.org.apache.spark.sql.JavaTestUtils.*;
+
+public class JavaHigherOrderFunctionsSuite {
+  private transient TestSparkSession spark;
+  private Dataset<Row> arrDf;
+  private Dataset<Row> mapDf;
+
+  private void setUpArrDf() {
+    List<Row> data = toRows(
+      makeArray(1, 9, 8, 7),
+      makeArray(5, 8, 9, 7, 2),
+      JavaTestUtils.<Integer>makeArray(),
+      null
+    );
+    StructType schema = new StructType()
+      .add("x", new ArrayType(IntegerType, true), true);
+    arrDf = spark.createDataFrame(data, schema);
+  }
+
+  private void setUpMapDf() {
+    List<Row> data = toRows(
+      new HashMap<Integer, Integer>() {{
+        put(1, 1);
+        put(2, 2);
+      }},
+      null
+    );
+    StructType schema = new StructType()
+      .add("x", new MapType(IntegerType, IntegerType, true));
+    mapDf = spark.createDataFrame(data, schema);
+  }
+
+  @Before
+  public void setUp() {
+    spark = new TestSparkSession();
+    setUpArrDf();
+    setUpMapDf();
+  }
+
+  @After
+  public void tearDown() {
+    spark.stop();
+    spark = null;
+  }
+
+  @Test
+  public void testTransform() {
+    checkAnswer(
+      arrDf.select(transform(col("x"), x -> x.plus(1))),
+      toRows(
+        makeArray(2, 10, 9, 8),
+        makeArray(6, 9, 10, 8, 3),
+        JavaTestUtils.<Integer>makeArray(),
+        null
+      )
+    );
+    checkAnswer(
+      arrDf.select(transform(col("x"), (x, i) -> x.plus(i))),
+      toRows(
+        makeArray(1, 10, 10, 10),
+        makeArray(5, 9, 11, 10, 6),
+        JavaTestUtils.<Integer>makeArray(),
+        null
+      )
+    );
+  }
+
+  @Test
+  public void testFilter() {
+    checkAnswer(
+      arrDf.select(filter(col("x"), x -> x.plus(1).equalTo(10))),
+      toRows(
+        makeArray(9),
+        makeArray(9),
+        JavaTestUtils.<Integer>makeArray(),
+        null
+      )
+    );
+  }
+
+  @Test
+  public void testExists() {
+    checkAnswer(
+      arrDf.select(exists(col("x"), x -> x.plus(1).equalTo(10))),
+      toRows(
+        true,
+        true,
+        false,
+        null
+      )
+    );
+  }
+
+  @Test
+  public void testForall() {
+    checkAnswer(
+      arrDf.select(forall(col("x"), x -> x.plus(1).equalTo(10))),
+      toRows(
+        false,
+        false,
+        true,
+        null
+      )
+    );
+  }
+
+  @Test
+  public void testAggregate() {
+    checkAnswer(
+      arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x))),
+      toRows(
+        25,
+        31,
+        0,
+        null
+      )
+    );
+    checkAnswer(
+      arrDf.select(aggregate(col("x"), lit(0), (acc, x) -> acc.plus(x), x -> x)),
+      toRows(
+        25,
+        31,
+        0,
+        null
+      )
+    );
+  }
+
+  @Test
+  public void testZipWith() {
+    checkAnswer(
+      arrDf.select(zip_with(col("x"), col("x"), (a, b) -> lit(42))),
+      toRows(
+        makeArray(42, 42, 42, 42),
+        makeArray(42, 42, 42, 42, 42),
+        JavaTestUtils.<Integer>makeArray(),
+        null
+      )
+    );
+  }
+
+  @Test
+  public void testTransformKeys() {
+    checkAnswer(
+      mapDf.select(transform_keys(col("x"), (k, v) -> k.plus(v))),
+      toRows(
+        mapAsScalaMap(new HashMap<Integer, Integer>() {{
+          put(2, 1);
+          put(4, 2);
+        }}),
+        null
+      )
+    );
+  }
+
+  @Test
+  public void testTransformValues() {
+    checkAnswer(
+      mapDf.select(transform_values(col("x"), (k, v) -> k.plus(v))),
+      toRows(
+        mapAsScalaMap(new HashMap<Integer, Integer>() {{
+          put(1, 2);
+          put(2, 4);
+        }}),
+        null
+      )
+    );
+  }
+
+  @Test
+  public void testMapFilter() {
+    checkAnswer(
+      mapDf.select(map_filter(col("x"), (k, v) -> lit(false))),
+      toRows(
+        mapAsScalaMap(new HashMap<Integer, Integer>()),
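+        // the null input row is expected to come back as a null result row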
+        null
+      )
+    );
+  }
+
+  @Test
+  public void testMapZipWith() {
+    checkAnswer(
+      mapDf.select(map_zip_with(col("x"), col("x"), (k, v1, v2) -> lit(false))),
+      toRows(
+        mapAsScalaMap(new HashMap<Integer, Boolean>() {{
+          put(1, false);
+          put(2, false);
+        }}),
+        null
+      )
+    );
+  }
+}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaTestUtils.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaTestUtils.java
new file mode 100644
index 000000000000..7fc6460e7352
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaTestUtils.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import java.util.Arrays;
+import java.util.List;
+import static java.util.stream.Collectors.toList;
+
+import scala.collection.mutable.WrappedArray;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+
+public class JavaTestUtils {
+  public static void checkAnswer(Dataset<Row> actual, List<Row> expected) {
+    assertEquals(expected, actual.collectAsList());
+  }
+
+  public static List<Row> toRows(Object... objs) {
+    return Arrays.asList(objs)
+      .stream()
+      .map(RowFactory::create)
+      .collect(toList());
+  }
+
+  public static <T> WrappedArray<T> makeArray(T... ts) {
+    return WrappedArray.make(ts);
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 7d044638db57..dbe72b64b4d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -1930,6 +1930,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(Seq(5, 9, 11, 10, 6)),
         Row(Seq.empty),
         Row(null)))
+    checkAnswer(df.select(transform(col("i"), x => x + 1)),
+      Seq(
+        Row(Seq(2, 10, 9, 8)),
+        Row(Seq(6, 9, 10, 8, 3)),
+        Row(Seq.empty),
+        Row(null)))
+    checkAnswer(df.select(transform(col("i"), (x, i) => x + i)),
+      Seq(
+        Row(Seq(1, 10, 10, 10)),
+        Row(Seq(5, 9, 11, 10, 6)),
+        Row(Seq.empty),
+        Row(null)))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -1960,6 +1972,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(Seq(5, null, 10, 12, 11, 7)),
         Row(Seq.empty),
         Row(null)))
+    checkAnswer(df.select(transform(col("i"), x => x + 1)),
+      Seq(
+        Row(Seq(2, 10, 9, null, 8)),
+        Row(Seq(6, null, 9, 10, 8, 3)),
+        Row(Seq.empty),
+        Row(null)))
+    checkAnswer(df.select(transform(col("i"), (x, i) => x + i)),
+      Seq(
+        Row(Seq(1, 10, 10, null, 11)),
+        Row(Seq(5, null, 10, 12, 11, 7)),
+        Row(Seq.empty),
+        Row(null)))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -1990,6 +2014,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(Seq("b0", null, "c2", null)),
         Row(Seq.empty),
         Row(null)))
+    checkAnswer(df.select(transform(col("s"), x => concat(x, x))),
+      Seq(
+        Row(Seq("cc", "aa", "bb")),
+        Row(Seq("bb", null, "cc", null)),
+        Row(Seq.empty),
+        Row(null)))
+    checkAnswer(df.select(transform(col("s"), (x, i) => concat(x, i))),
+      Seq(
+        Row(Seq("c0", "a1", "b2")),
+        Row(Seq("b0", null, "c2", null)),
+        Row(Seq.empty),
+        Row(null)))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -2034,6 +2070,32 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
           Seq("b", null, "c", null, null))),
         Row(Seq.empty),
         Row(null)))
+    checkAnswer(df.select(transform(col("arg"), arg => arg)),
+      Seq(
+        Row(Seq("c", "a", "b")),
+        Row(Seq("b", null, "c", null)),
+        Row(Seq.empty),
+        Row(null)))
+    checkAnswer(df.select(transform(col("arg"), _ => col("arg"))),
+      Seq(
+        Row(Seq(Seq("c", "a", "b"), Seq("c", "a", "b"), Seq("c", "a", "b"))),
+        Row(Seq(
+          Seq("b", null, "c", null),
+          Seq("b", null, "c", null),
+          Seq("b", null, "c", null),
+          Seq("b", null, "c", null))),
+        Row(Seq.empty),
+        Row(null)))
+    checkAnswer(df.select(transform(col("arg"), x => concat(col("arg"), array(x)))),
+      Seq(
+        Row(Seq(Seq("c", "a", "b", "c"), Seq("c", "a", "b", "a"), Seq("c", "a", "b", "b"))),
+        Row(Seq(
+          Seq("b", null, "c", null, "b"),
+          Seq("b", null, "c", null, null),
+          Seq("b", null, "c", null, "c"),
+          Seq("b", null, "c", null, null))),
+        Row(Seq.empty),
+        Row(null)))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -2080,6 +2142,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)),
         Row(Map(1 -> 10), Map(3 -> -3))))
 
+    checkAnswer(dfInts.select(
+      map_filter(col("m"), (k, v) => k * 10 === v),
+      map_filter(col("m"), (k, v) => k === (v * -1))),
+      Seq(
+        Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()),
+        Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)),
+        Row(Map(1 -> 10), Map(3 -> -3))))
+
     val dfComplex = Seq(
       Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))),
       Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m")
@@ -2090,6 +2160,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))),
         Row(Map(), Map(2 -> Seq(-2, -2)))))
 
+    checkAnswer(dfComplex.select(
+      map_filter(col("m"), (k, v) => k === element_at(v, 1)),
+      map_filter(col("m"), (k, v) => k === size(v))),
+      Seq(
+        Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))),
+        Row(Map(), Map(2 -> Seq(-2, -2)))))
+
     // Invalid use cases
     val df = Seq(
       (Map(1 -> "a"), 1),
       (Map(2 -> "b"), 2)).toDF("a", "b")
@@ -2112,6 +2189,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     }
     assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type"))
 
+    val ex3a = intercept[AnalysisException] {
+      df.select(map_filter(col("i"), (k, v) => k > v))
+    }
+    assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires map type"))
+
     val ex4 = intercept[AnalysisException] {
       df.selectExpr("map_filter(a, (k, v) -> k > v)")
     }
@@ -2133,6 +2215,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(Seq(8, 2)),
         Row(Seq.empty),
         Row(null)))
+    checkAnswer(df.select(filter(col("i"), _ % 2 === 0)),
+      Seq(
+        Row(Seq(8)),
+        Row(Seq(8, 2)),
+        Row(Seq.empty),
+        Row(null)))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -2157,6 +2245,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(Seq(8, 2)),
         Row(Seq.empty),
         Row(null)))
+    checkAnswer(df.select(filter(col("i"), _ % 2 === 0)),
+      Seq(
+        Row(Seq(8)),
+        Row(Seq(8, 2)),
+        Row(Seq.empty),
+        Row(null)))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -2181,6 +2275,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(Seq("b", "c")),
         Row(Seq.empty),
         Row(null)))
+    checkAnswer(df.select(filter(col("s"), x => x.isNotNull)),
+      Seq(
+        Row(Seq("c", "a", "b")),
+        Row(Seq("b", "c")),
+        Row(Seq.empty),
+        Row(null)))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -2208,11 +2308,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     }
     assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
 
+    val ex2a = intercept[AnalysisException] {
+      df.select(filter(col("i"), x => x))
+    }
+    assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type"))
+
     val ex3 = intercept[AnalysisException] {
       df.selectExpr("filter(s, x -> x)")
     }
     assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
 
+    val ex3a = intercept[AnalysisException] {
+      df.select(filter(col("s"), x => x))
+    }
+    assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
+
     val ex4 = intercept[AnalysisException] {
       df.selectExpr("filter(a, x -> x)")
     }
@@ -2234,6 +2344,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(false),
         Row(false),
         Row(null)))
+    checkAnswer(df.select(exists(col("i"), _ % 2 === 0)),
+      Seq(
+        Row(true),
+        Row(false),
+        Row(false),
+        Row(null)))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -2260,6 +2376,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
         Row(null),
         Row(false),
         Row(null)))
checkAnswer(df.select(exists(col("i"), _ % 2 === 0)), + Seq( + Row(true), + Row(false), + Row(null), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2284,6 +2407,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(true), Row(false), Row(null))) + checkAnswer(df.select(exists(col("s"), x => x.isNull)), + Seq( + Row(false), + Row(true), + Row(false), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2311,11 +2440,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2a = intercept[AnalysisException] { + df.select(exists(col("i"), x => x)) + } + assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3 = intercept[AnalysisException] { df.selectExpr("exists(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3a = intercept[AnalysisException] { + df.select(exists(df("s"), x => x)) + } + assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("exists(a, x -> x)") } @@ -2337,6 +2476,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(true), Row(true), Row(null))) + checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), + Seq( + Row(false), + Row(true), + Row(true), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2363,6 +2508,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(true), Row(true), Row(null))) + checkAnswer(df.select(forall(col("i"), x => (x % 2 === 0) || x.isNull)), + Seq( + Row(false), + Row(true), + Row(true), + Row(true), + Row(null))) checkAnswer(df.selectExpr("forall(i, x -> x % 2 == 0)"), Seq( Row(false), @@ -2370,6 +2522,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(true), Row(true), Row(null))) + checkAnswer(df.select(forall(col("i"), x => x % 2 === 0)), + Seq( + Row(false), + Row(null), + Row(true), + Row(true), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2394,6 +2553,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(true), Row(true), Row(null))) + checkAnswer(df.select(forall(col("s"), _.isNull)), + Seq( + Row(false), + Row(true), + Row(true), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2421,15 +2586,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex2a = intercept[AnalysisException] { + df.select(forall(col("i"), x => x)) + } + assert(ex2a.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3 = intercept[AnalysisException] { df.selectExpr("forall(s, x -> x)") } assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex3a = intercept[AnalysisException] { + df.select(forall(col("s"), x => x)) + } + assert(ex3a.getMessage.contains("data type mismatch: argument 2 requires boolean type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("forall(a, x -> x)") } assert(ex4.getMessage.contains("cannot resolve '`a`'")) + + val ex4a = intercept[AnalysisException] 
{ + df.select(forall(col("a"), x => x)) + } + assert(ex4a.getMessage.contains("cannot resolve '`a`'")) } test("aggregate function - array for primitive type not containing null") { @@ -2453,6 +2633,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(310), Row(0), Row(null))) + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), + Seq( + Row(25), + Row(31), + Row(0), + Row(null))) + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x, _ * 10)), + Seq( + Row(250), + Row(310), + Row(0), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2484,6 +2676,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(0), Row(0), Row(null))) + checkAnswer(df.select(aggregate(col("i"), lit(0), (acc, x) => acc + x)), + Seq( + Row(25), + Row(null), + Row(0), + Row(null))) + checkAnswer( + df.select( + aggregate(col("i"), lit(0), (acc, x) => acc + x, acc => coalesce(acc, lit(0)) * 10)), + Seq( + Row(250), + Row(0), + Row(0), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2515,6 +2721,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(""), Row("c"), Row(null))) + checkAnswer(df.select(aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x))), + Seq( + Row("acab"), + Row(null), + Row("c"), + Row(null))) + checkAnswer( + df.select( + aggregate(col("ss"), col("s"), (acc, x) => concat(acc, x), + acc => coalesce(acc, lit("")))), + Seq( + Row("acab"), + Row(""), + Row("c"), + Row(null))) } // Test with local relation, the Project will be evaluated without codegen @@ -2547,11 +2768,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { } assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex3a = intercept[AnalysisException] { + df.select(aggregate(col("i"), lit(0), (acc, x) => x)) + } + assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("aggregate(s, 0, (acc, x) -> x)") } assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) + val ex4a = intercept[AnalysisException] { + df.select(aggregate(col("s"), lit(0), (acc, x) => x)) + } + assert(ex4a.getMessage.contains("data type mismatch: argument 3 requires int type")) + val ex5 = intercept[AnalysisException] { df.selectExpr("aggregate(a, 0, (acc, x) -> x)") } @@ -2572,6 +2803,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(Map(10 -> null, 8 -> false, 4 -> null)), Row(Map(5 -> null)), Row(null))) + + checkAnswer(df.select(map_zip_with(df("m1"), df("m2"), (k, v1, v2) => k === v1 + v2)), + Seq( + Row(Map(8 -> true, 3 -> false, 6 -> true)), + Row(Map(10 -> null, 8 -> false, 4 -> null)), + Row(Map(5 -> null)), + Row(null))) } test("map_zip_with function - map of non-primitive types") { @@ -2588,6 +2826,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), Row(Map("a" -> Row("d", null))), Row(null))) + + checkAnswer(df.select(map_zip_with(col("m1"), col("m2"), (k, v1, v2) => struct(v1, v2))), + Seq( + Row(Map("z" -> Row("a", "c"), "y" -> Row("b", null), "x" -> Row("c", "a"))), + Row(Map("b" -> Row("a", null), "c" -> Row("d", "a"), "d" -> Row(null, "k"))), + Row(Map("a" -> Row("d", null))), + Row(null))) } test("map_zip_with function - 
invalid") { @@ -2606,16 +2851,32 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(ex2.getMessage.contains("The input to function map_zip_with should have " + "been two maps with compatible key types")) + val ex2a = intercept[AnalysisException] { + df.select(map_zip_with(df("mis"), col("mmi"), (x, y, z) => concat(x, y, z))) + } + assert(ex2a.getMessage.contains("The input to function map_zip_with should have " + + "been two maps with compatible key types")) + val ex3 = intercept[AnalysisException] { df.selectExpr("map_zip_with(i, mis, (x, y, z) -> concat(x, y, z))") } assert(ex3.getMessage.contains("type mismatch: argument 1 requires map type")) + val ex3a = intercept[AnalysisException] { + df.select(map_zip_with(col("i"), col("mis"), (x, y, z) => concat(x, y, z))) + } + assert(ex3a.getMessage.contains("type mismatch: argument 1 requires map type")) + val ex4 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mis, i, (x, y, z) -> concat(x, y, z))") } assert(ex4.getMessage.contains("type mismatch: argument 2 requires map type")) + val ex4a = intercept[AnalysisException] { + df.select(map_zip_with(col("mis"), col("i"), (x, y, z) => concat(x, y, z))) + } + assert(ex4a.getMessage.contains("type mismatch: argument 2 requires map type")) + val ex5 = intercept[AnalysisException] { df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") } @@ -2644,27 +2905,59 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"), Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + checkAnswer(dfExample1.select(transform_keys(col("i"), (k, v) => k + v)), + Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7)))) + checkAnswer(dfExample2.selectExpr("transform_keys(j, " + "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"), Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + checkAnswer(dfExample2.select( + transform_keys( + col("j"), + (k, v) => element_at( + map_from_arrays( + array(lit(1), lit(2), lit(3)), + array(lit("one"), lit("two"), lit("three")) + ), + k + ) + ) + ), + Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7)))) + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"), Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + checkAnswer(dfExample2.select(transform_keys(col("j"), + (k, v) => (v * 2).cast("bigint") + k)), + Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7)))) + checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"), Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)), + Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7)))) + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"), Seq(Row(Map(true -> true, true -> false)))) + checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)), + Seq(Row(Map(true -> true, true -> false)))) + checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), Seq(Row(Map(50 -> true, 78 -> false)))) - checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"), + checkAnswer(dfExample3.select(transform_keys(col("x"), + (k, v) => when(v, k * 2).otherwise(k * 3))), Seq(Row(Map(50 -> true, 78 -> false)))) checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"), Seq(Row(Map(false -> false)))) + + 
+    checkAnswer(dfExample4.select(transform_keys(col("y"),
+      (k, v) => array_contains(k, lit(3)) && v)),
+      Seq(Row(Map(false -> false))))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -2702,6 +2995,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     }
     assert(ex3.getMessage.contains("Cannot use null as map key"))
 
+    val ex3a = intercept[Exception] {
+      dfExample1.select(transform_keys(col("i"), (k, v) => v)).show()
+    }
+    assert(ex3a.getMessage.contains("Cannot use null as map key"))
+
     val ex4 = intercept[AnalysisException] {
       dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)")
     }
@@ -2766,6 +3064,46 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(
       dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"),
       Seq(Row(Map(1 -> 3))))
+
+    checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => k + v)),
+      Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14))))
+
+    checkAnswer(dfExample2.select(
+      transform_values(col("x"), (k, v) => when(k, v).otherwise(k.cast("string")))),
+      Seq(Row(Map(false -> "false", true -> "def"))))
+
+    checkAnswer(dfExample2.select(transform_values(col("x"),
+      (k, v) => (!k) && v === "abc")),
+      Seq(Row(Map(false -> true, true -> false))))
+
+    checkAnswer(dfExample3.select(transform_values(col("y"), (k, v) => v * v)),
+      Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9))))
+
+    checkAnswer(dfExample3.select(
+      transform_values(col("y"), (k, v) => concat(k, lit(":"), v.cast("string")))),
+      Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3"))))
+
+    checkAnswer(
+      dfExample3.select(transform_values(col("y"), (k, v) => concat(k, v.cast("string")))),
+      Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3"))))
+
+    val testMap = map_from_arrays(
+      array(lit(1), lit(2), lit(3)),
+      array(lit("one"), lit("two"), lit("three"))
+    )
+
+    checkAnswer(
+      dfExample4.select(transform_values(col("z"),
+        (k, v) => concat(element_at(testMap, k), lit("_"), v.cast("string")))),
+      Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 -> "three_1.7"))))
+
+    checkAnswer(
+      dfExample4.select(transform_values(col("z"), (k, v) => k - v)),
+      Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3))))
+
+    checkAnswer(
+      dfExample5.select(transform_values(col("c"), (k, v) => k + size(v))),
+      Seq(Row(Map(1 -> 3))))
   }
 
   // Test with local relation, the Project will be evaluated without codegen
@@ -2809,6 +3147,28 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"),
       Seq(Row(Map.empty[BigInt, BigInt])))
+
+    checkAnswer(dfExample1.select(transform_values(col("i"),
+      (k, v) => lit(null).cast("int"))),
+      Seq(Row(Map.empty[Integer, Integer])))
+
+    checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => k)),
+      Seq(Row(Map.empty[Integer, Integer])))
+
+    checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => v)),
+      Seq(Row(Map.empty[Integer, Integer])))
+
+    checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit(0))),
+      Seq(Row(Map.empty[Integer, Integer])))
+
+    checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit("value"))),
+      Seq(Row(Map.empty[Integer, String])))
+
+    checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => lit(true))),
+      Seq(Row(Map.empty[Integer, Boolean])))
+
+    checkAnswer(dfExample1.select(transform_values(col("i"), (k, v) => v.cast("bigint"))),
+      Seq(Row(Map.empty[BigInt, BigInt])))
   }
 
   testEmpty()
@@ -2833,6 +3193,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(dfExample2.selectExpr(
       "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"),
       Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4))))
+
+    checkAnswer(dfExample1.select(transform_values(col("a"),
+      (k, v) => lit(null).cast("int"))),
+      Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null))))
+
+    checkAnswer(dfExample2.select(
+      transform_values(col("b"), (k, v) => when(v.isNull, k + 1).otherwise(k + 2))
+    ),
+      Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4))))
   }
 
   testNullValue()
@@ -2871,6 +3240,12 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     }
     assert(ex3.getMessage.contains(
       "data type mismatch: argument 1 requires map type"))
+
+    val ex3a = intercept[AnalysisException] {
+      dfExample3.select(transform_values(col("x"), (k, v) => k + 1))
+    }
+    assert(ex3a.getMessage.contains(
+      "data type mismatch: argument 1 requires map type"))
   }
 
   testInvalidLambdaFunctions()
@@ -2897,10 +3272,15 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(Seq.empty),
       Row(null))
     checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1)
+    checkAnswer(df1.select(zip_with(df1("val1"), df1("val2"), (x, y) => x + y)), expectedValue1)
     val expectedValue2 = Seq(
       Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))),
       Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3))))
     checkAnswer(df2.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue2)
+    checkAnswer(
+      df2.select(zip_with(df2("val1"), df2("val2"), (x, y) => struct(y, x))),
+      expectedValue2
+    )
   }
 
   test("arrays zip_with function - for non-primitive types") {
@@ -2915,7 +3295,14 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(Seq(Row("x", "a"), Row("y", null))),
       Row(Seq.empty),
       Row(null))
-    checkAnswer(df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue1)
+    checkAnswer(
+      df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"),
+      expectedValue1
+    )
+    checkAnswer(
+      df.select(zip_with(col("val1"), col("val2"), (x, y) => struct(y, x))),
+      expectedValue1
+    )
   }
 
   test("arrays zip_with function - invalid") {
@@ -2937,6 +3324,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       df.selectExpr("zip_with(i, a2, (acc, x) -> x)")
     }
     assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type"))
+    val ex3a = intercept[AnalysisException] {
+      df.select(zip_with(df("i"), df("a2"), (acc, x) => x))
+    }
+    assert(ex3a.getMessage.contains("data type mismatch: argument 1 requires array type"))
     val ex4 = intercept[AnalysisException] {
       df.selectExpr("zip_with(a1, a, (acc, x) -> x)")
     }