@@ -24,7 +24,102 @@ import org.apache.spark.sql.types._
2424
2525class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
2626 import org .apache .spark .sql .catalyst .dsl .expressions ._
27- import org .apache .spark .sql .catalyst .expressions .HigherOrderUtils ._
27+
28+ private def createLambda (
29+ dt : DataType ,
30+ nullable : Boolean ,
31+ f : Expression => Expression ): Expression = {
32+ val lv = NamedLambdaVariable (" arg" , dt, nullable)
33+ val function = f(lv)
34+ LambdaFunction (function, Seq (lv))
35+ }
36+
37+ private def createLambda (
38+ dt1 : DataType ,
39+ nullable1 : Boolean ,
40+ dt2 : DataType ,
41+ nullable2 : Boolean ,
42+ f : (Expression , Expression ) => Expression ): Expression = {
43+ val lv1 = NamedLambdaVariable (" arg1" , dt1, nullable1)
44+ val lv2 = NamedLambdaVariable (" arg2" , dt2, nullable2)
45+ val function = f(lv1, lv2)
46+ LambdaFunction (function, Seq (lv1, lv2))
47+ }
48+
49+ private def createLambda (
50+ dt1 : DataType ,
51+ nullable1 : Boolean ,
52+ dt2 : DataType ,
53+ nullable2 : Boolean ,
54+ dt3 : DataType ,
55+ nullable3 : Boolean ,
56+ f : (Expression , Expression , Expression ) => Expression ): Expression = {
57+ val lv1 = NamedLambdaVariable (" arg1" , dt1, nullable1)
58+ val lv2 = NamedLambdaVariable (" arg2" , dt2, nullable2)
59+ val lv3 = NamedLambdaVariable (" arg3" , dt3, nullable3)
60+ val function = f(lv1, lv2, lv3)
61+ LambdaFunction (function, Seq (lv1, lv2, lv3))
62+ }
63+
64+ private def validateBinding (
65+ e : Expression ,
66+ argInfo : Seq [(DataType , Boolean )]): LambdaFunction = e match {
67+ case f : LambdaFunction =>
68+ assert(f.arguments.size === argInfo.size)
69+ f.arguments.zip(argInfo).foreach {
70+ case (arg, (dataType, nullable)) =>
71+ assert(arg.dataType === dataType)
72+ assert(arg.nullable === nullable)
73+ }
74+ f
75+ }
76+
77+ def transform (expr : Expression , f : Expression => Expression ): Expression = {
78+ val ArrayType (et, cn) = expr.dataType
79+ ArrayTransform (expr, createLambda(et, cn, f)).bind(validateBinding)
80+ }
81+
82+ def transform (expr : Expression , f : (Expression , Expression ) => Expression ): Expression = {
83+ val ArrayType (et, cn) = expr.dataType
84+ ArrayTransform (expr, createLambda(et, cn, IntegerType , false , f)).bind(validateBinding)
85+ }
86+
87+ def filter (expr : Expression , f : Expression => Expression ): Expression = {
88+ val ArrayType (et, cn) = expr.dataType
89+ ArrayFilter (expr, createLambda(et, cn, f)).bind(validateBinding)
90+ }
91+
92+ def transformKeys (expr : Expression , f : (Expression , Expression ) => Expression ): Expression = {
93+ val MapType (kt, vt, vcn) = expr.dataType
94+ TransformKeys (expr, createLambda(kt, false , vt, vcn, f)).bind(validateBinding)
95+ }
96+
97+ def aggregate (
98+ expr : Expression ,
99+ zero : Expression ,
100+ merge : (Expression , Expression ) => Expression ,
101+ finish : Expression => Expression ): Expression = {
102+ val ArrayType (et, cn) = expr.dataType
103+ val zeroType = zero.dataType
104+ ArrayAggregate (
105+ expr,
106+ zero,
107+ createLambda(zeroType, true , et, cn, merge),
108+ createLambda(zeroType, true , finish))
109+ .bind(validateBinding)
110+ }
111+
112+ def aggregate (
113+ expr : Expression ,
114+ zero : Expression ,
115+ merge : (Expression , Expression ) => Expression ): Expression = {
116+ aggregate(expr, zero, merge, identity)
117+ }
118+
119+ def transformValues (expr : Expression , f : (Expression , Expression ) => Expression ): Expression = {
120+ val MapType (kt, vt, vcn) = expr.dataType
121+ TransformValues (expr, createLambda(kt, false , vt, vcn, f)).bind(validateBinding)
122+ }
28123
29124 test(" ArrayTransform" ) {
30125 val ai0 = Literal .create(Seq (1 , 2 , 3 ), ArrayType (IntegerType , containsNull = false ))
@@ -68,6 +163,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
68163 }
69164
70165 test(" MapFilter" ) {
166+ def mapFilter (expr : Expression , f : (Expression , Expression ) => Expression ): Expression = {
167+ val MapType (kt, vt, vcn) = expr.dataType
168+ MapFilter (expr, createLambda(kt, false , vt, vcn, f)).bind(validateBinding)
169+ }
71170 val mii0 = Literal .create(Map (1 -> 0 , 2 -> 10 , 3 -> - 1 ),
72171 MapType (IntegerType , IntegerType , valueContainsNull = false ))
73172 val mii1 = Literal .create(Map (1 -> null , 2 -> 10 , 3 -> null ),
@@ -145,6 +244,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
145244 }
146245
147246 test(" ArrayExists" ) {
247+ def exists (expr : Expression , f : Expression => Expression ): Expression = {
248+ val ArrayType (et, cn) = expr.dataType
249+ ArrayExists (expr, createLambda(et, cn, f)).bind(validateBinding)
250+ }
251+
148252 val ai0 = Literal .create(Seq (1 , 2 , 3 ), ArrayType (IntegerType , containsNull = false ))
149253 val ai1 = Literal .create(Seq [Integer ](1 , null , 3 ), ArrayType (IntegerType , containsNull = true ))
150254 val ain = Literal .create(null , ArrayType (IntegerType , containsNull = false ))
@@ -353,6 +457,16 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
353457 }
354458
355459 test(" MapZipWith" ) {
460+ def map_zip_with (
461+ left : Expression ,
462+ right : Expression ,
463+ f : (Expression , Expression , Expression ) => Expression ): Expression = {
464+ val MapType (kt, vt1, _) = left.dataType
465+ val MapType (_, vt2, _) = right.dataType
466+ MapZipWith (left, right, createLambda(kt, false , vt1, true , vt2, true , f))
467+ .bind(validateBinding)
468+ }
469+
356470 val mii0 = Literal .create(create_map(1 -> 10 , 2 -> 20 , 3 -> 30 ),
357471 MapType (IntegerType , IntegerType , valueContainsNull = false ))
358472 val mii1 = Literal .create(create_map(1 -> - 1 , 2 -> - 2 , 4 -> - 4 ),
@@ -435,6 +549,15 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
435549 }
436550
437551 test(" ZipWith" ) {
552+ def zip_with (
553+ left : Expression ,
554+ right : Expression ,
555+ f : (Expression , Expression ) => Expression ): Expression = {
556+ val ArrayType (leftT, _) = left.dataType
557+ val ArrayType (rightT, _) = right.dataType
558+ ZipWith (left, right, createLambda(leftT, true , rightT, true , f)).bind(validateBinding)
559+ }
560+
438561 val ai0 = Literal .create(Seq (1 , 2 , 3 ), ArrayType (IntegerType , containsNull = false ))
439562 val ai1 = Literal .create(Seq (1 , 2 , 3 , 4 ), ArrayType (IntegerType , containsNull = false ))
440563 val ai2 = Literal .create(Seq [Integer ](1 , null , 3 ), ArrayType (IntegerType , containsNull = true ))
0 commit comments