Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ object FunctionRegistry {
expression[ArrayRemove]("array_remove"),
expression[ArrayDistinct]("array_distinct"),
expression[ArrayTransform]("transform"),
expression[ArrayFilter]("filter"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.concurrent.atomic.AtomicReference

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
Expand Down Expand Up @@ -140,6 +142,18 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu
@transient lazy val functionForEval: Expression = functionsForEval.head
}

object ArrayBasedHigherOrderFunction {

def elementArgumentType(dt: DataType): (DataType, Boolean) = {
dt match {
case ArrayType(elementType, containsNull) => (elementType, containsNull)
case _ =>
val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
(elementType, containsNull)
}
}
}

/**
* Transform elements in an array using the transform function. This is similar to
* a `map` in functional programming.
Expand All @@ -164,17 +178,12 @@ case class ArrayTransform(
override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
val (elementType, containsNull) = input.dataType match {
case ArrayType(elementType, containsNull) => (elementType, containsNull)
case _ =>
val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
(elementType, containsNull)
}
val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
function match {
case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
copy(function = f(function, elem :: (IntegerType, false) :: Nil))
case _ =>
copy(function = f(function, (elementType, containsNull) :: Nil))
copy(function = f(function, elem :: Nil))
}
}

Expand Down Expand Up @@ -210,3 +219,54 @@ case class ArrayTransform(

override def prettyName: String = "transform"
}

/**
* Filters the input array using the given lambda function.
*/
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Filters the input array using the given predicate.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
array(1, 3)
""",
since = "2.4.0")
case class ArrayFilter(
input: Expression,
function: Expression)
extends ArrayBasedHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = input.nullable

override def dataType: DataType = input.dataType

override def expectingFunctionType: AbstractDataType = BooleanType

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
copy(function = f(function, elem :: Nil))
}

@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function

override def eval(input: InternalRow): Any = {
val arr = this.input.eval(input).asInstanceOf[ArrayData]
if (arr == null) {
null
} else {
val f = functionForEval
val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (f.eval(input).asInstanceOf[Boolean]) {
buffer += elementVar.value.get
}
i += 1
}
new GenericArrayData(buffer)
}
}

override def prettyName: String = "filter"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is filter too generic? wdyt?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it might be. How about array_filter?

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayTransform(expr, createLambda(at.elementType, at.containsNull, IntegerType, false, f))
}

def filter(expr: Expression, f: Expression => Expression): Expression = {
val at = expr.dataType.asInstanceOf[ArrayType]
ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
}

test("ArrayTransform") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
Expand Down Expand Up @@ -94,4 +99,36 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)),
Seq("[1, 3, 5]", null, "[4, 6]"))
}

test("ArrayFilter") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false))

val isEven: Expression => Expression = x => x % 2 === 0
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1

checkEvaluation(filter(ai0, isEven), Seq(2))
checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3))
checkEvaluation(filter(ai1, isEven), Seq.empty)
checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3))
checkEvaluation(filter(ain, isEven), null)
checkEvaluation(filter(ain, isNullOrOdd), null)

val as0 =
Literal.create(Seq("a0", "b1", "a2", "c3"), ArrayType(StringType, containsNull = false))
val as1 = Literal.create(Seq("a", null, "c"), ArrayType(StringType, containsNull = true))
val asn = Literal.create(null, ArrayType(StringType, containsNull = false))

val startsWithA: Expression => Expression = x => x.startsWith("a")

checkEvaluation(filter(as0, startsWithA), Seq("a0", "a2"))
checkEvaluation(filter(as1, startsWithA), Seq("a"))
checkEvaluation(filter(asn, startsWithA), null)

val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)),
Seq(Seq(1, 3), null, Seq(5)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,12 @@ select transform(ys, 0) as v from nested;

-- Transform a null array
select transform(cast(null as array<int>), x -> x + 1) as v;

-- Filter.
select filter(ys, y -> y > 30) as v from nested;

-- Filter a null array
select filter(cast(null as array<int>), y -> true) as v;

-- Filter nested arrays
select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested;
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 8
-- Number of queries: 11


-- !query 0
Expand Down Expand Up @@ -79,3 +79,31 @@ select transform(cast(null as array<int>), x -> x + 1) as v
struct<v:array<int>>
-- !query 7 output
NULL


-- !query 8
select filter(ys, y -> y > 30) as v from nested
-- !query 8 schema
struct<v:array<int>>
-- !query 8 output
[32,97]
[77]
[]


-- !query 9
select filter(cast(null as array<int>), y -> true) as v
-- !query 9 schema
struct<v:array<int>>
-- !query 9 output
NULL


-- !query 10
select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested
-- !query 10 schema
struct<v:array<array<int>>>
-- !query 10 output
[[96,65],[]]
[[99],[123],[]]
[[]]
Original file line number Diff line number Diff line change
Expand Up @@ -1800,6 +1800,102 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
}

test("filter function - array for primitive type not containing null") {
val df = Seq(
Seq(1, 9, 8, 7),
Seq(5, 8, 9, 7, 2),
Seq.empty,
null
).toDF("i")

def testArrayOfPrimitiveTypeNotContainsNull(): Unit = {
checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"),
Seq(
Row(Seq(8)),
Row(Seq(8, 2)),
Row(Seq.empty),
Row(null)))
}

// Test with local relation, the Project will be evaluated without codegen
testArrayOfPrimitiveTypeNotContainsNull()
// Test with cached relation, the Project will be evaluated with codegen
df.cache()
testArrayOfPrimitiveTypeNotContainsNull()
}

test("filter function - array for primitive type containing null") {
val df = Seq[Seq[Integer]](
Seq(1, 9, 8, null, 7),
Seq(5, null, 8, 9, 7, 2),
Seq.empty,
null
).toDF("i")

def testArrayOfPrimitiveTypeContainsNull(): Unit = {
checkAnswer(df.selectExpr("filter(i, x -> x % 2 == 0)"),
Seq(
Row(Seq(8)),
Row(Seq(8, 2)),
Row(Seq.empty),
Row(null)))
}

// Test with local relation, the Project will be evaluated without codegen
testArrayOfPrimitiveTypeContainsNull()
// Test with cached relation, the Project will be evaluated with codegen
df.cache()
testArrayOfPrimitiveTypeContainsNull()
}

test("filter function - array for non-primitive type") {
val df = Seq(
Seq("c", "a", "b"),
Seq("b", null, "c", null),
Seq.empty,
null
).toDF("s")

def testNonPrimitiveType(): Unit = {
checkAnswer(df.selectExpr("filter(s, x -> x is not null)"),
Seq(
Row(Seq("c", "a", "b")),
Row(Seq("b", "c")),
Row(Seq.empty),
Row(null)))
}

// Test with local relation, the Project will be evaluated without codegen
testNonPrimitiveType()
// Test with cached relation, the Project will be evaluated with codegen
df.cache()
testNonPrimitiveType()
}

test("filter function - invalid") {
val df = Seq(
(Seq("c", "a", "b"), 1),
(Seq("b", null, "c", null), 2),
(Seq.empty, 3),
(null, 4)
).toDF("s", "i")

val ex1 = intercept[AnalysisException] {
df.selectExpr("filter(s, (x, y) -> x + y)")
}
assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))

val ex2 = intercept[AnalysisException] {
df.selectExpr("filter(i, x -> x)")
}
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))

val ex3 = intercept[AnalysisException] {
df.selectExpr("filter(s, x -> x)")
}
assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down