[SPARK-28962][SQL] Provide index argument to filter lambda functions #25666
```diff
@@ -344,8 +344,13 @@ case class MapFilter(
     Examples:
       > SELECT _FUNC_(array(1, 2, 3), x -> x % 2 == 1);
        [1,3]
+      > SELECT _FUNC_(array(0, 2, 3), (x, i) -> x > i);
+       [2,3]
   """,
-  since = "2.4.0")
+  since = "2.4.0",
+  note = """
+    The inner function may use the index argument since 3.0.0.
+  """)
 case class ArrayFilter(
     argument: Expression,
     function: Expression)
```
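For reference, a sketch of how the new form surfaces to end users once `_FUNC_` expands to `filter` (a minimal example assuming a SparkSession named `spark` built from this branch; the session name is illustrative):

```scala
// The second lambda parameter receives the element's index (new in this PR).
spark.sql("SELECT filter(array(1, 2, 3), x -> x % 2 == 1)").show()  // [1, 3]
spark.sql("SELECT filter(array(0, 2, 3), (x, i) -> x > i)").show()  // [2, 3]
```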
```diff
@@ -357,10 +362,19 @@ case class ArrayFilter(

   override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
     val ArrayType(elementType, containsNull) = argument.dataType
-    copy(function = f(function, (elementType, containsNull) :: Nil))
+    function match {
+      case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
+        copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
+      case _ =>
+        copy(function = f(function, (elementType, containsNull) :: Nil))
```
Member
Don't we need to validate the number of arguments here (the case: arguments.size > 2)?

Member
Can you check the current error message for that case?

Contributor
ArrayTransform doesn't validate arguments.size > 2. I'm not sure what happens in that case either.

Member
nvm, I checked that the error handling works well for this case.

Contributor (Author)
Yes, it does. See the test here: https://github.com/apache/spark/pull/25666/files#diff-8e1a34391fdefa4a3a0349d7d454d86fR2204
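For concreteness, a hypothetical illustration of the over-arity case discussed above (the session name `spark` is illustrative; the exact error message is not quoted in this thread, only covered by the test linked above):

```scala
spark.sql("SELECT filter(array(0, 2, 3), (x, i) -> x > i)")    // two lambda arguments: resolves, i is the index
spark.sql("SELECT filter(array(1, 2, 3), (x, i, j) -> x > i)") // three lambda arguments: rejected with an error
```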
```diff
+    }
   }

-  @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
+  @transient lazy val (elementVar, indexVar) = {
+    val LambdaFunction(_, (elementVar: NamedLambdaVariable) +: tail, _) = function
+    val indexVar = tail.headOption.map(_.asInstanceOf[NamedLambdaVariable])
+    (elementVar, indexVar)
+  }

   override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
     val arr = argumentValue.asInstanceOf[ArrayData]
```
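The destructuring in the new lazy val can be illustrated in isolation (a standalone sketch with placeholder values, not Spark code):

```scala
// The first lambda argument (the element) is mandatory; a second one (the index) is optional.
val args = Seq("x", "i")            // stand-ins for the lambda's NamedLambdaVariables
val element +: tail = args          // element = "x", tail = Seq("i")
val index = tail.headOption         // Some("i") for a two-argument lambda, None for one argument
println((element, index))           // (x,Some(i))
```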
```diff
@@ -369,6 +383,9 @@ case class ArrayFilter(
     var i = 0
     while (i < arr.numElements) {
       elementVar.value.set(arr.get(i, elementVar.dataType))
+      if (indexVar.isDefined) {
```
Member
Can you avoid this per-row check? The current code causes unnecessary runtime overhead.

Contributor (Author)
@maropu do you have a suggestion for how to do this without implementing codegen? I tried rewriting the logic like so:

```scala
@transient private lazy val evalFn: (InternalRow, Any) => Any = indexVar match {
  case None => (inputRow, argumentValue) =>
    val arr = argumentValue.asInstanceOf[ArrayData]
    val f = functionForEval
    val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
    var i = 0
    while (i < arr.numElements) {
      elementVar.value.set(arr.get(i, elementVar.dataType))
      if (f.eval(inputRow).asInstanceOf[Boolean]) {
        buffer += elementVar.value.get
      }
      i += 1
    }
    new GenericArrayData(buffer)
  case Some(expr) => (inputRow, argumentValue) =>
    val arr = argumentValue.asInstanceOf[ArrayData]
    val f = functionForEval
    val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
    var i = 0
    while (i < arr.numElements) {
      elementVar.value.set(arr.get(i, elementVar.dataType))
      expr.value.set(i)
      if (f.eval(inputRow).asInstanceOf[Boolean]) {
        buffer += elementVar.value.get
      }
      i += 1
    }
    new GenericArrayData(buffer)
}

override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
  evalFn(inputRow, argumentValue)
}
```

But from some hacky microbenchmarking this doesn't seem to be meaningfully faster, and if anything is marginally slower.
Contributor (Author)
This is the benchmark code I was using:

```scala
test("ArrayFilter - benchmark") {
  import scala.concurrent.duration._
  val b = new Benchmark(
    "array_filter",
    1000,
    warmupTime = 5.seconds,
    minTime = 5.seconds)
  val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
  val isEven: Expression => Expression = x => x % 2 === 0
  b.addCase("filter") { _ =>
    var i = 0
    while (i < 1000) {
      filter(ai0, isEven).eval()
      i += 1
    }
  }
  b.run()
}
```
Contributor
@maropu @henrydavidge The best performing way to avoid the per-row check in a non-codegen setting is to introduce a new expression type. The tradeoff between the inline per-row check and the lambda batch solution is that on input arrays that are small (like the one @henrydavidge used in his benchmark), the overhead of the lambda invocation (which is not guaranteed to be inlined and optimized) may exceed the overhead of the per-row check. You'd need a fairly large input array to amortize that. If we want to keep it simple for now, I'm okay with the inline per-row check version.
Member
I thought code like this:
Contributor (Author)
Ok, tried that as well. It doesn't seem to be significantly different from the others.
Member
Yea, if there's no big difference, I'd prefer handling this similarly to the others, e.g., https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala#L555
Member
I think this is good enough to go.
Contributor
+1, this is ready to go for now and we can address the optimization separately. Side comment on the version @maropu gave (i.e. the "element-wise lambda"): assume we're running on the HotSpot JVM and both the with-index and without-index paths have been exercised; then the best the HotSpot JIT compiler could have done is a profile-guided bimorphic devirtualization of that lambda call site, which turns the call back into an if/else on the receiver's class with the two bodies inlined. The point is that this JIT-optimized version is actually a degenerated version of Henry's hand-written inline per-element check, so I wouldn't want to go down this route.
Member
Thanks, kris! That explanation's very helpful to me.
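To make the devirtualization argument above concrete, here is a minimal, self-contained sketch (all names invented, not taken from the PR) of the "element-wise lambda" shape: the per-element step is chosen once, but the call inside the loop stays virtual, and once both step implementations have hit that call site, bimorphic devirtualization effectively reintroduces a per-element branch on the receiver's class.

```scala
// Standalone illustration only; not Spark code.
object ElementWiseLambdaSketch {
  type Step = Int => Unit   // per-element action, analogous to "set the index slot (or not)"

  def run(values: Array[Int], step: Step): Array[Int] = {
    val out = Array.newBuilder[Int]
    var i = 0
    while (i < values.length) {
      step(i)                                   // virtual call on every element
      if (values(i) % 2 == 0) out += values(i)  // stand-in for the filter predicate
      i += 1
    }
    out.result()
  }

  def main(args: Array[String]): Unit = {
    var lastIndex = -1
    val withIndex: Step = i => lastIndex = i    // analogous to indexVar.get.value.set(i)
    val noIndex: Step = _ => ()                 // analogous to the index-less path
    println(run(Array(1, 2, 3, 4), withIndex).mkString(","))  // 2,4
    println(run(Array(1, 2, 3, 4), noIndex).mkString(","))    // 2,4
    // After both Step implementations have been used at the same call site, the JIT's best case
    // is bimorphic dispatch: effectively `if (step is withIndex) ... else ...` on every element,
    // i.e. the same per-element check the hand-written inline version performs explicitly.
  }
}
```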
```diff
+        indexVar.get.value.set(i)
+      }
       if (f.eval(inputRow).asInstanceOf[Boolean]) {
         buffer += elementVar.value.get
       }
```
Here, the indices start at 0, but it sounds like the other built-in functions start at 1.

I remember there was a (not-merged) PR to standardize one-based column indexes in built-in functions: #24051. Better to fix them up for consistency?
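A small illustration of the inconsistency being raised (the session name `spark` is illustrative; `element_at` is an existing built-in that uses 1-based positions):

```scala
// The lambda index added by this PR is 0-based...
spark.sql("SELECT filter(array(0, 2, 3), (x, i) -> x > i)").show()  // [2, 3]: i takes the values 0, 1, 2
// ...whereas, for example, element_at addresses array elements with 1-based positions.
spark.sql("SELECT element_at(array(0, 2, 3), 1)").show()            // 0, the first element
```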