diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b993e1a9bad6..061336455189 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -448,6 +448,8 @@ object FunctionRegistry { expression[ArrayAggregate]("aggregate"), expression[TransformKeys]("transform_keys"), expression[MapZipWith]("map_zip_with"), + expression[ZipWith]("zip_with"), + CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index a305a05add7a..9d603d79eedc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -740,3 +740,79 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) override def prettyName: String = "map_zip_with" } + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(left, right, func) - Merges the two given arrays, element-wise, into a single array using function. 
If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying function.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)); + array(('a', 1), ('b', 2), ('c', 3)) + > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y); + array(4, 6) + > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)); + array('ad', 'be', 'cf') + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class ZipWith(left: Expression, right: Expression, function: Expression) + extends HigherOrderFunction with CodegenFallback { + + def functionForEval: Expression = functionsForEval.head + + override def arguments: Seq[Expression] = left :: right :: Nil + + override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil + + override def functions: Seq[Expression] = List(function) + + override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil + + override def nullable: Boolean = left.nullable || right.nullable + + override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ZipWith = { + val ArrayType(leftElementType, _) = left.dataType + val ArrayType(rightElementType, _) = right.dataType + copy(function = f(function, + (leftElementType, true) :: (rightElementType, true) :: Nil)) + } + + @transient lazy val LambdaFunction(_, + Seq(leftElemVar: NamedLambdaVariable, rightElemVar: NamedLambdaVariable), _) = function + + override def eval(input: InternalRow): Any = { + val leftArr = left.eval(input).asInstanceOf[ArrayData] + if (leftArr == null) { + null + } else { + val rightArr = right.eval(input).asInstanceOf[ArrayData] + if (rightArr == null) { + null + } else { + val resultLength = math.max(leftArr.numElements(), rightArr.numElements()) + val f = functionForEval + val result = new GenericArrayData(new 
Array[Any](resultLength)) + var i = 0 + while (i < resultLength) { + if (i < leftArr.numElements()) { + leftElemVar.value.set(leftArr.get(i, leftElemVar.dataType)) + } else { + leftElemVar.value.set(null) + } + if (i < rightArr.numElements()) { + rightElemVar.value.set(rightArr.get(i, rightElemVar.dataType)) + } else { + rightElemVar.value.set(null) + } + result.update(i, f.eval(input)) + i += 1 + } + result + } + } + } + + override def prettyName: String = "zip_with" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 12ef01816835..3a78f14c8b2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -471,4 +471,52 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper map_zip_with(mbb0, mbbn, concat), null) } + + test("ZipWith") { + def zip_with( + left: Expression, + right: Expression, + f: (Expression, Expression) => Expression): Expression = { + val ArrayType(leftT, _) = left.dataType + val ArrayType(rightT, _) = right.dataType + ZipWith(left, right, createLambda(leftT, true, rightT, true, f)) + } + + val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) + val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false)) + val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) + val ai3 = Literal.create(Seq[Integer](1, null), ArrayType(IntegerType, containsNull = true)) + val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false)) + + val add: (Expression, Expression) => Expression = (x, y) => x + y + val plusOne: Expression => Expression = x => x + 1 + + checkEvaluation(zip_with(ai0, 
ai1, add), Seq(2, 4, 6, null)) + checkEvaluation(zip_with(ai3, ai2, add), Seq(2, null, null)) + checkEvaluation(zip_with(ai2, ai3, add), Seq(2, null, null)) + checkEvaluation(zip_with(ain, ain, add), null) + checkEvaluation(zip_with(ai1, ain, add), null) + checkEvaluation(zip_with(ain, ai1, add), null) + + val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false)) + val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true)) + val as2 = Literal.create(Seq("a"), ArrayType(StringType, containsNull = true)) + val asn = Literal.create(null, ArrayType(StringType, containsNull = false)) + + val concat: (Expression, Expression) => Expression = (x, y) => Concat(Seq(x, y)) + + checkEvaluation(zip_with(as0, as1, concat), Seq("aa", null, "cc")) + checkEvaluation(zip_with(as0, as2, concat), Seq("aa", null, null)) + + val aai1 = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + val aai2 = Literal.create(Seq(Seq(1, 2, 3)), + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)) + checkEvaluation( + zip_with(aai1, aai2, (a1, a2) => + Cast(zip_with(transform(a1, plusOne), transform(a2, plusOne), add), StringType)), + Seq("[4, 6, 8]", null, null)) + checkEvaluation(zip_with(aai1, aai1, (a1, a2) => Cast(transform(a1, plusOne), StringType)), + Seq("[2, 3, 4]", null, "[5, 6]")) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 9a8454455ae7..05ec5effdf14 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -51,7 +51,16 @@ select exists(ys, y -> y > 30) as v from nested; -- Check for element existence in a null array select exists(cast(null as array<int>), y -> y > 30) as v; - + +-- Zip with array 
+select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested; + +-- Zip with array with concat +select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v; + +-- Zip with array coalesce +select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v; + create or replace temporary view nested as values (1, map(1, 1, 2, 2, 3, 3)), (2, map(4, 4, 5, 5, 6, 6)) diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index b77bda7bb267..5a39616191e8 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -166,37 +166,63 @@ NULL -- !query 17 +select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested +-- !query 17 schema +struct<v:array<int>> +-- !query 17 output +[13] +[34,99,null] +[80,-74] + + +-- !query 18 +select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v +-- !query 18 schema +struct<v:array<string>> +-- !query 18 output +["ad","be","cf"] + + +-- !query 19 +select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v +-- !query 19 schema +struct<v:array<string>> +-- !query 19 output +["a",null,"f"] + + +-- !query 20 create or replace temporary view nested as values (1, map(1, 1, 2, 2, 3, 3)), (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys) --- !query 17 schema +-- !query 20 schema struct<> --- !query 17 output +-- !query 20 output --- !query 18 +-- !query 21 select transform_keys(ys, (k, v) -> k) as v from nested --- !query 18 schema +-- !query 21 schema struct<v:map<int,int>> --- !query 18 output +-- !query 21 output {1:1,2:2,3:3} {4:4,5:5,6:6} --- !query 19 +-- !query 22 select transform_keys(ys, (k, v) -> k + 1) as v from nested --- !query 19 schema +-- !query 22 schema struct<v:map<int,int>> --- !query 19 output +-- !query 22 output {2:1,3:2,4:3} {5:4,6:5,7:6} --- !query 20 +-- !query 23 select 
transform_keys(ys, (k, v) -> k + v) as v from nested --- !query 20 schema +-- !query 23 schema struct<v:map<int,int>> --- !query 20 output +-- !query 23 output {10:5,12:6,8:4} {2:1,4:2,6:3} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 22f191209f87..9e2bfd3b7fba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2389,6 +2389,69 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { "data type mismatch: argument 1 requires map type")) } + test("arrays zip_with function - for primitive types") { + val df1 = Seq[(Seq[Integer], Seq[Integer])]( + (Seq(9001, 9002, 9003), Seq(4, 5, 6)), + (Seq(1, 2), Seq(3, 4)), + (Seq.empty, Seq.empty), + (null, null) + ).toDF("val1", "val2") + val df2 = Seq[(Seq[Integer], Seq[Long])]( + (Seq(1, null, 3), Seq(1L, 2L)), + (Seq(1, 2, 3), Seq(4L, 11L)) + ).toDF("val1", "val2") + val expectedValue1 = Seq( + Row(Seq(9005, 9007, 9009)), + Row(Seq(4, 6)), + Row(Seq.empty), + Row(null)) + checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) + val expectedValue2 = Seq( + Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))), + Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3)))) + checkAnswer(df2.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue2) + } + + test("arrays zip_with function - for non-primitive types") { + val df = Seq( + (Seq("a"), Seq("x", "y", "z")), + (Seq("a", null), Seq("x", "y")), + (Seq.empty[String], Seq.empty[String]), + (Seq("a", "b", "c"), null) + ).toDF("val1", "val2") + val expectedValue1 = Seq( + Row(Seq(Row("x", "a"), Row("y", null), Row("z", null))), + Row(Seq(Row("x", "a"), Row("y", null))), + Row(Seq.empty), + Row(null)) + checkAnswer(df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue1) + } + + test("arrays 
zip_with function - invalid") { + val df = Seq( + (Seq("c", "a", "b"), Seq("x", "y", "z"), 1), + (Seq("b", null, "c", null), Seq("x"), 2), + (Seq.empty, Seq("x", "z"), 3), + (null, Seq("x", "z"), 4) + ).toDF("a1", "a2", "i") + val ex1 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a2, x -> x)") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + val ex2 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a2, (acc, x) -> x, (acc, x) -> x)") + } + assert(ex2.getMessage.contains("Invalid number of arguments for function zip_with")) + val ex3 = intercept[AnalysisException] { + df.selectExpr("zip_with(i, a2, (acc, x) -> x)") + } + assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type")) + val ex4 = intercept[AnalysisException] { + df.selectExpr("zip_with(a1, a, (acc, x) -> x)") + } + assert(ex4.getMessage.contains("cannot resolve '`a`'")) + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {