Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,13 @@ case class MapFilter(
Examples:
> SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
[1,3]
> SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, the indices start at 0, but it sounds like the other built-in functions start at 1.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember there was a (not-merged) PR to standardize one-based column indexes in built-in funcs: #24051
Better to fix them up for consistency?

[2,3]
""",
since = "2.4.0")
since = "2.4.0",
note = """
The inner function may use the index argument since 3.0.0.
""")
case class ArrayFilter(
argument: Expression,
function: Expression)
Expand All @@ -357,10 +362,19 @@ case class ArrayFilter(

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
val ArrayType(elementType, containsNull) = argument.dataType
copy(function = f(function, (elementType, containsNull) :: Nil))
function match {
case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
case _ =>
copy(function = f(function, (elementType, containsNull) :: Nil))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to validate # of arguments here? (the case: arguments.size > 2)

Copy link
Member

@maropu maropu Sep 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check the current error message for that case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArrayTransform doesn't validate arguments.size > 2. I'm not sure what happens in that case either.

Copy link
Member

@maropu maropu Sep 20, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nvm. I checked the error handling works well for the case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
}

@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
// Splits the bound lambda's arguments into the mandatory element variable and the
// optional second (index) variable; `indexVar` is None for a one-argument lambda.
@transient lazy val (elementVar, indexVar) = {
  val LambdaFunction(_, (first: NamedLambdaVariable) +: rest, _) = function
  (first, rest.headOption.map(_.asInstanceOf[NamedLambdaVariable]))
}

override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val arr = argumentValue.asInstanceOf[ArrayData]
Expand All @@ -369,6 +383,9 @@ case class ArrayFilter(
var i = 0
while (i < arr.numElements) {
elementVar.value.set(arr.get(i, elementVar.dataType))
if (indexVar.isDefined) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you avoid this per-row check? The current code causes unnecessary runtime overhead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu do you have a suggestion about how to do this without implementing codegen? I tried rewriting the logic like so:

  // Chooses the evaluation loop once, on first access, according to whether an index
  // variable is present, so that neither loop body has to test for it per element.
  @transient private lazy val evalFn: (InternalRow, Any) => Any = indexVar match {
    case Some(idxVar) => (row, input) =>
      // Two-argument lambda: publish both the element and its zero-based position.
      val elements = input.asInstanceOf[ArrayData]
      val predicate = functionForEval
      val kept = new mutable.ArrayBuffer[Any](elements.numElements)
      var pos = 0
      while (pos < elements.numElements) {
        elementVar.value.set(elements.get(pos, elementVar.dataType))
        idxVar.value.set(pos)
        if (predicate.eval(row).asInstanceOf[Boolean]) {
          kept += elementVar.value.get
        }
        pos += 1
      }
      new GenericArrayData(kept)

    case None => (row, input) =>
      // One-argument lambda: only the element variable needs to be populated.
      val elements = input.asInstanceOf[ArrayData]
      val predicate = functionForEval
      val kept = new mutable.ArrayBuffer[Any](elements.numElements)
      var pos = 0
      while (pos < elements.numElements) {
        elementVar.value.set(elements.get(pos, elementVar.dataType))
        if (predicate.eval(row).asInstanceOf[Boolean]) {
          kept += elementVar.value.get
        }
        pos += 1
      }
      new GenericArrayData(kept)

  }


  /** Hands the row and the argument value straight to `evalFn`. */
  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any =
    evalFn(inputRow, argumentValue)

But from some hacky microbenchmarking this doesn't seem to be meaningfully faster and if anything is marginally slower.

Copy link
Contributor Author

@henrydavidge henrydavidge Oct 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the benchmark code I was using:

  // Ad-hoc micro-benchmark: repeatedly builds and evaluates `filter(ai0, isEven)` on a
  // three-element integer array literal.
  test("ArrayFilter - benchmark") {
    import scala.concurrent.duration._
    // NOTE(review): the second argument (1000) presumably is the number of values
    // processed per iteration, matching the inner loop count below — confirm against
    // the Benchmark constructor.
    val b = new Benchmark(
      "array_filter",
      1000,
      warmupTime = 5.seconds,
      minTime = 5.seconds)
    // Small, non-nullable input array; only the one-argument predicate is exercised here.
    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
    val isEven: Expression => Expression = x => x % 2 === 0
    b.addCase("filter") { _ =>
      // Each case run constructs and evaluates the filter expression 1000 times.
      var i = 0
      while (i < 1000) {
        filter(ai0, isEven).eval()
        i += 1
      }
    }
    b.run()
  }

Copy link
Contributor

@rednaxelafx rednaxelafx Oct 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maropu @henrydavidge The best performing way to avoid the per-row check in a non-codegen setting is to introduce a new expression type, say ArrayFilterWithIndex.

The tradeoff between the inline per-row check and the lambda batch solution is that on input arrays that are small (like the one @henrydavidge used in his benchmark), the lambda invocation (which is not guaranteed to be inlined+optimized) overhead may exceed the per-row check overhead. You'd need a fairly large input array to amortize that.

If we want to make it stay simple for now, I'm okay with the inline per-row check version.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking of code like this:


  // Resolves the element variable together with a callback that records the current
  // position: a no-op for a one-argument lambda, a setter on the index variable for a
  // two-argument lambda.
  @transient lazy val (elementVar, mayFillIndex) = function match {
    case LambdaFunction(_, Seq(element: NamedLambdaVariable, index: NamedLambdaVariable), _) =>
      (element, (pos: Int) => index.value.set(pos))
    case LambdaFunction(_, Seq(element: NamedLambdaVariable), _) =>
      (element, (_: Int) => {})
  }

  /**
   * Walks the input array in order, exposing each element (and, through `mayFillIndex`,
   * its position) to the lambda, and collects the elements for which the lambda
   * evaluates to true.
   */
  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val elements = argumentValue.asInstanceOf[ArrayData]
    val predicate = functionForEval
    val kept = new mutable.ArrayBuffer[Any](elements.numElements)
    var pos = 0
    while (pos < elements.numElements) {
      elementVar.value.set(elements.get(pos, elementVar.dataType))
      mayFillIndex(pos)
      val accepted = predicate.eval(inputRow).asInstanceOf[Boolean]
      if (accepted) {
        kept += elementVar.value.get
      }
      pos += 1
    }
    new GenericArrayData(kept)
  }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, tried that as well. It doesn't seem to be significantly different from the others.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is good enough to go.
How about merging this for now, and addressing it in a separate PR?
transform does it the same way, so I think we should do the same thing there if needed, maybe at the same time.

Copy link
Contributor

@rednaxelafx rednaxelafx Oct 2, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for this is ready to go for now and we can address the optimization separately.

Side-comment on the version that @maropu gave:
The lambda version that @henrydavidge gave (i.e. "batch-wise lambda") would technically have less overhead:

// lambda invocation overhead outside of loop
for each element in array
  do specialized filter action

whereas the version that @maropu gave (i.e. "element-wise lambda") would be:

// shared loop between the two versions
for each element in array
  // lambda invocation overhead per element
  invoke mayFillIndex lambda

With @maropu 's version, let's assume that we're running on the HotSpot JVM and both the with-index and without-index paths have been used, then the best the HotSpot JIT compiler could have done is a profile-guided bimorphic devirtualization on that lambda call site, which will look like the following after devirtualization+inlining:

local_mayFillIndex = this.mayFillIndex
klazz = local_mayFillIndex.klass
for each element in array
  // ...
  if (klazz == lambda_klass_1) {
    // no-op
  } else if (klazz == lambda_klass_2) {
    idxVar.value.set(i)
  } else {
    uncommon_trap() // aka deoptimize, or potentially a full virtual call
  }
}

The point is that this JIT-optimized version is actually a degenerate version of Henry's hand-written inline per-element check version, so I wouldn't want to go down this route.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, kris! That explanation's very helpful to me.

indexVar.get.value.set(i)
}
if (f.eval(inputRow).asInstanceOf[Boolean]) {
buffer += elementVar.value.get
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
}

// Builds a bound ArrayFilter whose lambda takes two arguments: the array element and
// a non-nullable integer.
def filter(expr: Expression, f: (Expression, Expression) => Expression): Expression =
  expr.dataType match {
    case ArrayType(elementType, elementNullable) =>
      val lambda = createLambda(elementType, elementNullable, IntegerType, false, f)
      ArrayFilter(expr, lambda).bind(validateBinding)
  }

def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val MapType(kt, vt, vcn) = expr.dataType
TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
Expand Down Expand Up @@ -218,9 +223,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper

val isEven: Expression => Expression = x => x % 2 === 0
val isNullOrOdd: Expression => Expression = x => x.isNull || x % 2 === 1
val indexIsEven: (Expression, Expression) => Expression = { case (_, idx) => idx % 2 === 0 }

checkEvaluation(filter(ai0, isEven), Seq(2))
checkEvaluation(filter(ai0, isNullOrOdd), Seq(1, 3))
checkEvaluation(filter(ai0, indexIsEven), Seq(1, 3))
checkEvaluation(filter(ai1, isEven), Seq.empty)
checkEvaluation(filter(ai1, isNullOrOdd), Seq(1, null, 3))
checkEvaluation(filter(ain, isEven), null)
Expand All @@ -234,13 +241,17 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
val startsWithA: Expression => Expression = x => x.startsWith("a")

checkEvaluation(filter(as0, startsWithA), Seq("a0", "a2"))
checkEvaluation(filter(as0, indexIsEven), Seq("a0", "a2"))
checkEvaluation(filter(as1, startsWithA), Seq("a"))
checkEvaluation(filter(as1, indexIsEven), Seq("a", "c"))
checkEvaluation(filter(asn, startsWithA), null)

val aai = Literal.create(Seq(Seq(1, 2, 3), null, Seq(4, 5)),
ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
checkEvaluation(transform(aai, ix => filter(ix, isNullOrOdd)),
Seq(Seq(1, 3), null, Seq(5)))
checkEvaluation(transform(aai, ix => filter(ix, indexIsEven)),
Seq(Seq(1, 3), null, Seq(4)))
}

test("ArrayExists") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2190,6 +2190,30 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
testNonPrimitiveType()
}

test("filter function - index argument") {
  val input = Seq(
    Seq("c", "a", "b"),
    Seq("b", null, "c", null),
    Seq.empty,
    null
  ).toDF("s")

  // Keeping only even positions retains the 1st and 3rd entries of each array.
  def verifyEvenPositions(): Unit = {
    val filtered = input.selectExpr("filter(s, (x, i) -> i % 2 == 0)")
    val expected = Seq(
      Row(Seq("c", "b")),
      Row(Seq("b", "c")),
      Row(Seq.empty),
      Row(null))
    checkAnswer(filtered, expected)
  }

  // First pass: local relation, so the Project is evaluated without codegen.
  verifyEvenPositions()
  // Second pass: cached relation, so the Project is evaluated with codegen.
  input.cache()
  verifyEvenPositions()
}

test("filter function - invalid") {
val df = Seq(
(Seq("c", "a", "b"), 1),
Expand All @@ -2199,9 +2223,9 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
).toDF("s", "i")

val ex1 = intercept[AnalysisException] {
df.selectExpr("filter(s, (x, y) -> x + y)")
df.selectExpr("filter(s, (x, y, z) -> x + y)")
}
assert(ex1.getMessage.contains("The number of lambda function arguments '2' does not match"))
assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match"))

val ex2 = intercept[AnalysisException] {
df.selectExpr("filter(i, x -> x)")
Expand Down