Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,8 @@ object FunctionRegistry {
expression[ArrayAggregate]("aggregate"),
expression[TransformKeys]("transform_keys"),
expression[MapZipWith]("map_zip_with"),
expression[ZipWith]("zip_with"),

CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,79 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)

override def prettyName: String = "map_zip_with"
}

// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(left, right, func) - Merges the two given arrays, element-wise, into a single array using function. If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying function.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x));
array(('a', 1), ('b', 3), ('c', 5))
> SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y));
array(4, 6)
> SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y));
array('ad', 'be', 'cf')
""",
since = "2.4.0")
// scalastyle:on line.size.limit
case class ZipWith(left: Expression, right: Expression, function: Expression)
extends HigherOrderFunction with CodegenFallback {

def functionForEval: Expression = functionsForEval.head

override def arguments: Seq[Expression] = left :: right :: Nil

override def argumentTypes: Seq[AbstractDataType] = ArrayType :: ArrayType :: Nil

override def functions: Seq[Expression] = List(function)

override def functionTypes: Seq[AbstractDataType] = AnyDataType :: Nil

override def nullable: Boolean = left.nullable || right.nullable

override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ZipWith = {
val ArrayType(leftElementType, _) = left.dataType
val ArrayType(rightElementType, _) = right.dataType
copy(function = f(function,
(leftElementType, true) :: (rightElementType, true) :: Nil))
}

@transient lazy val LambdaFunction(_,
Seq(leftElemVar: NamedLambdaVariable, rightElemVar: NamedLambdaVariable), _) = function

override def eval(input: InternalRow): Any = {
val leftArr = left.eval(input).asInstanceOf[ArrayData]
if (leftArr == null) {
null
} else {
val rightArr = right.eval(input).asInstanceOf[ArrayData]
if (rightArr == null) {
null
} else {
val resultLength = math.max(leftArr.numElements(), rightArr.numElements())
val f = functionForEval
val result = new GenericArrayData(new Array[Any](resultLength))
var i = 0
while (i < resultLength) {
if (i < leftArr.numElements()) {
leftElemVar.value.set(leftArr.get(i, leftElemVar.dataType))
} else {
leftElemVar.value.set(null)
}
if (i < rightArr.numElements()) {
rightElemVar.value.set(rightArr.get(i, rightElemVar.dataType))
} else {
rightElemVar.value.set(null)
}
result.update(i, f.eval(input))
i += 1
}
result
}
}
}

override def prettyName: String = "zip_with"
}
Original file line number Diff line number Diff line change
Expand Up @@ -471,4 +471,52 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
map_zip_with(mbb0, mbbn, concat),
null)
}

test("ZipWith") {
def zip_with(
left: Expression,
right: Expression,
f: (Expression, Expression) => Expression): Expression = {
val ArrayType(leftT, _) = left.dataType
val ArrayType(rightT, _) = right.dataType
ZipWith(left, right, createLambda(leftT, true, rightT, true, f))
}

val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false))
val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
val ai3 = Literal.create(Seq[Integer](1, null), ArrayType(IntegerType, containsNull = true))
val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false))

val add: (Expression, Expression) => Expression = (x, y) => x + y
val plusOne: Expression => Expression = x => x + 1

checkEvaluation(zip_with(ai0, ai1, add), Seq(2, 4, 6, null))
checkEvaluation(zip_with(ai3, ai2, add), Seq(2, null, null))
checkEvaluation(zip_with(ai2, ai3, add), Seq(2, null, null))
checkEvaluation(zip_with(ain, ain, add), null)
checkEvaluation(zip_with(ai1, ain, add), null)
checkEvaluation(zip_with(ain, ai1, add), null)

val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType, containsNull = false))
val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))
val as2 = Literal.create(Seq("a"), ArrayType(StringType, containsNull = true))
val asn = Literal.create(null, ArrayType(StringType, containsNull = false))

val concat: (Expression, Expression) => Expression = (x, y) => Concat(Seq(x, y))

checkEvaluation(zip_with(as0, as1, concat), Seq("aa", null, "cc"))
checkEvaluation(zip_with(as0, as2, concat), Seq("aa", null, null))

val aai1 = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
val aai2 = Literal.create(Seq(Seq(1, 2, 3)),
ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
checkEvaluation(
zip_with(aai1, aai2, (a1, a2) =>
Cast(zip_with(transform(a1, plusOne), transform(a2, plusOne), add), StringType)),
Seq("[4, 6, 8]", null, null))
checkEvaluation(zip_with(aai1, aai1, (a1, a2) => Cast(transform(a1, plusOne), StringType)),
Seq("[2, 3, 4]", null, "[5, 6]"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,16 @@ select exists(ys, y -> y > 30) as v from nested;

-- Check for element existence in a null array
select exists(cast(null as array<int>), y -> y > 30) as v;


-- Zip with array
select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested;

-- Zip with array with concat
select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v;

-- Zip with array coalesce
select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v;

create or replace temporary view nested as values
(1, map(1, 1, 2, 2, 3, 3)),
(2, map(4, 4, 5, 5, 6, 6))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,37 +166,63 @@ NULL


-- !query 17
select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested
-- !query 17 schema
struct<v:array<int>>
-- !query 17 output
[13]
[34,99,null]
[80,-74]


-- !query 18
select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v
-- !query 18 schema
struct<v:array<string>>
-- !query 18 output
["ad","be","cf"]


-- !query 19
select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v
-- !query 19 schema
struct<v:array<string>>
-- !query 19 output
["a",null,"f"]


-- !query 20
create or replace temporary view nested as values
(1, map(1, 1, 2, 2, 3, 3)),
(2, map(4, 4, 5, 5, 6, 6))
as t(x, ys)
-- !query 17 schema
-- !query 20 schema
struct<>
-- !query 17 output
-- !query 20 output


-- !query 18
-- !query 21
select transform_keys(ys, (k, v) -> k) as v from nested
-- !query 18 schema
-- !query 21 schema
struct<v:map<int,int>>
-- !query 18 output
-- !query 21 output
{1:1,2:2,3:3}
{4:4,5:5,6:6}


-- !query 19
-- !query 22
select transform_keys(ys, (k, v) -> k + 1) as v from nested
-- !query 19 schema
-- !query 22 schema
struct<v:map<int,int>>
-- !query 19 output
-- !query 22 output
{2:1,3:2,4:3}
{5:4,6:5,7:6}


-- !query 20
-- !query 23
select transform_keys(ys, (k, v) -> k + v) as v from nested
-- !query 20 schema
-- !query 23 schema
struct<v:map<int,int>>
-- !query 20 output
-- !query 23 output
{10:5,12:6,8:4}
{2:1,4:2,6:3}
Original file line number Diff line number Diff line change
Expand Up @@ -2389,6 +2389,69 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
"data type mismatch: argument 1 requires map type"))
}

test("arrays zip_with function - for primitive types") {
val df1 = Seq[(Seq[Integer], Seq[Integer])](
(Seq(9001, 9002, 9003), Seq(4, 5, 6)),
(Seq(1, 2), Seq(3, 4)),
(Seq.empty, Seq.empty),
(null, null)
).toDF("val1", "val2")
val df2 = Seq[(Seq[Integer], Seq[Long])](
(Seq(1, null, 3), Seq(1L, 2L)),
(Seq(1, 2, 3), Seq(4L, 11L))
).toDF("val1", "val2")
val expectedValue1 = Seq(
Row(Seq(9005, 9007, 9009)),
Row(Seq(4, 6)),
Row(Seq.empty),
Row(null))
checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1)
val expectedValue2 = Seq(
Row(Seq(Row(1L, 1), Row(2L, null), Row(null, 3))),
Row(Seq(Row(4L, 1), Row(11L, 2), Row(null, 3))))
checkAnswer(df2.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue2)
}

test("arrays zip_with function - for non-primitive types") {
val df = Seq(
(Seq("a"), Seq("x", "y", "z")),
(Seq("a", null), Seq("x", "y")),
(Seq.empty[String], Seq.empty[String]),
(Seq("a", "b", "c"), null)
).toDF("val1", "val2")
val expectedValue1 = Seq(
Row(Seq(Row("x", "a"), Row("y", null), Row("z", null))),
Row(Seq(Row("x", "a"), Row("y", null))),
Row(Seq.empty),
Row(null))
checkAnswer(df.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue1)
}

test("arrays zip_with function - invalid") {
val df = Seq(
(Seq("c", "a", "b"), Seq("x", "y", "z"), 1),
(Seq("b", null, "c", null), Seq("x"), 2),
(Seq.empty, Seq("x", "z"), 3),
(null, Seq("x", "z"), 4)
).toDF("a1", "a2", "i")
val ex1 = intercept[AnalysisException] {
df.selectExpr("zip_with(a1, a2, x -> x)")
}
assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match"))
val ex2 = intercept[AnalysisException] {
df.selectExpr("zip_with(a1, a2, (acc, x) -> x, (acc, x) -> x)")
}
assert(ex2.getMessage.contains("Invalid number of arguments for function zip_with"))
val ex3 = intercept[AnalysisException] {
df.selectExpr("zip_with(i, a2, (acc, x) -> x)")
}
assert(ex3.getMessage.contains("data type mismatch: argument 1 requires array type"))
val ex4 = intercept[AnalysisException] {
df.selectExpr("zip_with(a1, a, (acc, x) -> x)")
}
assert(ex4.getMessage.contains("cannot resolve '`a`'"))
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down