Skip to content

Commit 6bf07d8

Browse files
committed
Do not prematurely bind lambda variables
1 parent 03d602f commit 6bf07d8

File tree

4 files changed

+268
-255
lines changed

4 files changed

+268
-255
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 0 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -28,139 +28,6 @@ import org.apache.spark.sql.catalyst.util._
2828
import org.apache.spark.sql.types._
2929
import org.apache.spark.unsafe.array.ByteArrayMethods
3030

31-
/**
32-
* Helper methods for constructing higher order functions.
33-
*/
34-
object HigherOrderUtils {
35-
def createLambda(
36-
dt: DataType,
37-
nullable: Boolean,
38-
f: Expression => Expression): Expression = {
39-
val lv = NamedLambdaVariable("arg", dt, nullable)
40-
val function = f(lv)
41-
LambdaFunction(function, Seq(lv))
42-
}
43-
44-
def createLambda(
45-
dt1: DataType,
46-
nullable1: Boolean,
47-
dt2: DataType,
48-
nullable2: Boolean,
49-
f: (Expression, Expression) => Expression): Expression = {
50-
val lv1 = NamedLambdaVariable("arg1", dt1, nullable1)
51-
val lv2 = NamedLambdaVariable("arg2", dt2, nullable2)
52-
val function = f(lv1, lv2)
53-
LambdaFunction(function, Seq(lv1, lv2))
54-
}
55-
56-
def createLambda(
57-
dt1: DataType,
58-
nullable1: Boolean,
59-
dt2: DataType,
60-
nullable2: Boolean,
61-
dt3: DataType,
62-
nullable3: Boolean,
63-
f: (Expression, Expression, Expression) => Expression): Expression = {
64-
val lv1 = NamedLambdaVariable("arg1", dt1, nullable1)
65-
val lv2 = NamedLambdaVariable("arg2", dt2, nullable2)
66-
val lv3 = NamedLambdaVariable("arg3", dt3, nullable3)
67-
val function = f(lv1, lv2, lv3)
68-
LambdaFunction(function, Seq(lv1, lv2, lv3))
69-
}
70-
71-
def validateBinding(
72-
e: Expression,
73-
argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
74-
case f: LambdaFunction =>
75-
assert(f.arguments.size == argInfo.size)
76-
f.arguments.zip(argInfo).foreach {
77-
case (arg, (dataType, nullable)) =>
78-
assert(arg.dataType == dataType)
79-
assert(arg.nullable == nullable)
80-
}
81-
f
82-
}
83-
84-
// Array-based helpers
85-
def filter(expr: Expression, f: Expression => Expression): Expression = {
86-
val ArrayType(et, cn) = expr.dataType
87-
ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
88-
}
89-
90-
def exists(expr: Expression, f: Expression => Expression): Expression = {
91-
val ArrayType(et, cn) = expr.dataType
92-
ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding)
93-
}
94-
95-
def transform(expr: Expression, f: Expression => Expression): Expression = {
96-
val ArrayType(et, cn) = expr.dataType
97-
ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding)
98-
}
99-
100-
def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
101-
val ArrayType(et, cn) = expr.dataType
102-
ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding)
103-
}
104-
105-
def aggregate(
106-
expr: Expression,
107-
zero: Expression,
108-
merge: (Expression, Expression) => Expression,
109-
finish: Expression => Expression): Expression = {
110-
val ArrayType(et, cn) = expr.dataType
111-
val zeroType = zero.dataType
112-
ArrayAggregate(
113-
expr,
114-
zero,
115-
createLambda(zeroType, true, et, cn, merge),
116-
createLambda(zeroType, true, finish))
117-
.bind(validateBinding)
118-
}
119-
120-
def aggregate(
121-
expr: Expression,
122-
zero: Expression,
123-
merge: (Expression, Expression) => Expression): Expression = {
124-
aggregate(expr, zero, merge, identity)
125-
}
126-
127-
def zip_with(
128-
left: Expression,
129-
right: Expression,
130-
f: (Expression, Expression) => Expression): Expression = {
131-
val ArrayType(leftT, _) = left.dataType
132-
val ArrayType(rightT, _) = right.dataType
133-
ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding)
134-
}
135-
136-
// Map-based helpers
137-
138-
def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
139-
val MapType(kt, vt, vcn) = expr.dataType
140-
TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
141-
}
142-
143-
def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
144-
val MapType(kt, vt, vcn) = expr.dataType
145-
TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
146-
}
147-
148-
def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
149-
val MapType(kt, vt, vcn) = expr.dataType
150-
MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
151-
}
152-
153-
def map_zip_with(
154-
left: Expression,
155-
right: Expression,
156-
f: (Expression, Expression, Expression) => Expression): Expression = {
157-
val MapType(kt, vt1, _) = left.dataType
158-
val MapType(_, vt2, _) = right.dataType
159-
MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f))
160-
.bind(validateBinding)
161-
}
162-
}
163-
16431
/**
16532
* A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]].
16633
*/

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,102 @@ import org.apache.spark.sql.types._
2424

2525
class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
2626
import org.apache.spark.sql.catalyst.dsl.expressions._
27-
import org.apache.spark.sql.catalyst.expressions.HigherOrderUtils._
27+
28+
private def createLambda(
29+
dt: DataType,
30+
nullable: Boolean,
31+
f: Expression => Expression): Expression = {
32+
val lv = NamedLambdaVariable("arg", dt, nullable)
33+
val function = f(lv)
34+
LambdaFunction(function, Seq(lv))
35+
}
36+
37+
private def createLambda(
38+
dt1: DataType,
39+
nullable1: Boolean,
40+
dt2: DataType,
41+
nullable2: Boolean,
42+
f: (Expression, Expression) => Expression): Expression = {
43+
val lv1 = NamedLambdaVariable("arg1", dt1, nullable1)
44+
val lv2 = NamedLambdaVariable("arg2", dt2, nullable2)
45+
val function = f(lv1, lv2)
46+
LambdaFunction(function, Seq(lv1, lv2))
47+
}
48+
49+
private def createLambda(
50+
dt1: DataType,
51+
nullable1: Boolean,
52+
dt2: DataType,
53+
nullable2: Boolean,
54+
dt3: DataType,
55+
nullable3: Boolean,
56+
f: (Expression, Expression, Expression) => Expression): Expression = {
57+
val lv1 = NamedLambdaVariable("arg1", dt1, nullable1)
58+
val lv2 = NamedLambdaVariable("arg2", dt2, nullable2)
59+
val lv3 = NamedLambdaVariable("arg3", dt3, nullable3)
60+
val function = f(lv1, lv2, lv3)
61+
LambdaFunction(function, Seq(lv1, lv2, lv3))
62+
}
63+
64+
private def validateBinding(
65+
e: Expression,
66+
argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
67+
case f: LambdaFunction =>
68+
assert(f.arguments.size === argInfo.size)
69+
f.arguments.zip(argInfo).foreach {
70+
case (arg, (dataType, nullable)) =>
71+
assert(arg.dataType === dataType)
72+
assert(arg.nullable === nullable)
73+
}
74+
f
75+
}
76+
77+
def transform(expr: Expression, f: Expression => Expression): Expression = {
78+
val ArrayType(et, cn) = expr.dataType
79+
ArrayTransform(expr, createLambda(et, cn, f)).bind(validateBinding)
80+
}
81+
82+
def transform(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
83+
val ArrayType(et, cn) = expr.dataType
84+
ArrayTransform(expr, createLambda(et, cn, IntegerType, false, f)).bind(validateBinding)
85+
}
86+
87+
def filter(expr: Expression, f: Expression => Expression): Expression = {
88+
val ArrayType(et, cn) = expr.dataType
89+
ArrayFilter(expr, createLambda(et, cn, f)).bind(validateBinding)
90+
}
91+
92+
def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
93+
val MapType(kt, vt, vcn) = expr.dataType
94+
TransformKeys(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
95+
}
96+
97+
def aggregate(
98+
expr: Expression,
99+
zero: Expression,
100+
merge: (Expression, Expression) => Expression,
101+
finish: Expression => Expression): Expression = {
102+
val ArrayType(et, cn) = expr.dataType
103+
val zeroType = zero.dataType
104+
ArrayAggregate(
105+
expr,
106+
zero,
107+
createLambda(zeroType, true, et, cn, merge),
108+
createLambda(zeroType, true, finish))
109+
.bind(validateBinding)
110+
}
111+
112+
def aggregate(
113+
expr: Expression,
114+
zero: Expression,
115+
merge: (Expression, Expression) => Expression): Expression = {
116+
aggregate(expr, zero, merge, identity)
117+
}
118+
119+
def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
120+
val MapType(kt, vt, vcn) = expr.dataType
121+
TransformValues(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
122+
}
28123

29124
test("ArrayTransform") {
30125
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
@@ -68,6 +163,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
68163
}
69164

70165
test("MapFilter") {
166+
def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
167+
val MapType(kt, vt, vcn) = expr.dataType
168+
MapFilter(expr, createLambda(kt, false, vt, vcn, f)).bind(validateBinding)
169+
}
71170
val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1),
72171
MapType(IntegerType, IntegerType, valueContainsNull = false))
73172
val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null),
@@ -145,6 +244,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
145244
}
146245

147246
test("ArrayExists") {
247+
def exists(expr: Expression, f: Expression => Expression): Expression = {
248+
val ArrayType(et, cn) = expr.dataType
249+
ArrayExists(expr, createLambda(et, cn, f)).bind(validateBinding)
250+
}
251+
148252
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
149253
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
150254
val ain = Literal.create(null, ArrayType(IntegerType, containsNull = false))
@@ -353,6 +457,16 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
353457
}
354458

355459
test("MapZipWith") {
460+
def map_zip_with(
461+
left: Expression,
462+
right: Expression,
463+
f: (Expression, Expression, Expression) => Expression): Expression = {
464+
val MapType(kt, vt1, _) = left.dataType
465+
val MapType(_, vt2, _) = right.dataType
466+
MapZipWith(left, right, createLambda(kt, false, vt1, true, vt2, true, f))
467+
.bind(validateBinding)
468+
}
469+
356470
val mii0 = Literal.create(create_map(1 -> 10, 2 -> 20, 3 -> 30),
357471
MapType(IntegerType, IntegerType, valueContainsNull = false))
358472
val mii1 = Literal.create(create_map(1 -> -1, 2 -> -2, 4 -> -4),
@@ -435,6 +549,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
435549
}
436550

437551
test("ZipWith") {
552+
def zip_with(
553+
left: Expression,
554+
right: Expression,
555+
f: (Expression, Expression) => Expression): Expression = {
556+
val ArrayType(leftT, _) = left.dataType
557+
val ArrayType(rightT, _) = right.dataType
558+
ZipWith(left, right, createLambda(leftT, true, rightT, true, f)).bind(validateBinding)
559+
}
560+
438561
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
439562
val ai1 = Literal.create(Seq(1, 2, 3, 4), ArrayType(IntegerType, containsNull = false))
440563
val ai2 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))

0 commit comments

Comments
 (0)