@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
1919
2020import java .lang .reflect .Modifier
2121
22+ import scala .collection .JavaConverters ._
2223import scala .collection .mutable .Builder
2324import scala .language .existentials
2425import scala .reflect .ClassTag
@@ -501,12 +502,22 @@ case class LambdaVariable(
501502 value : String ,
502503 isNull : String ,
503504 dataType : DataType ,
504- nullable : Boolean = true ) extends LeafExpression
505- with Unevaluable with NonSQLExpression {
505+ nullable : Boolean = true ) extends LeafExpression with NonSQLExpression {
506+
507+ // Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row.
// The assertion below requires that row to have exactly one field; the value at ordinal 0 is
// then read using this variable's `dataType`.
508+ override def eval (input : InternalRow ): Any = {
509+ assert(input.numFields == 1 ,
510+ " The input row of interpreted LambdaVariable should have only 1 field." )
511+ input.get(0 , dataType)
512+ }
506513
// Emits no code of its own: the returned `ExprCode` has a blank `code` string and only carries
// this variable's `value` name, plus its `isNull` name when `nullable` is set (a fixed literal
// is used for `isNull` otherwise).
507514 override def genCode (ctx : CodegenContext ): ExprCode = {
508515 ExprCode (code = " " , value = value, isNull = if (nullable) isNull else " false" )
509516 }
517+
518+ // This won't be called because `genCode` is overridden; it is only implemented to make
519+ // `LambdaVariable` non-abstract. It returns `ev` unchanged.
520+ override protected def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = ev
510521}
511522
512523/**
@@ -599,8 +610,92 @@ case class MapObjects private(
599610
// Children in order: the lambda function first, then the expression producing the input data.
600611 override def children : Seq [Expression ] = lambdaFunction :: inputData :: Nil
601612
602- override def eval (input : InternalRow ): Any =
603- throw new UnsupportedOperationException (" Only code-generated evaluation is supported" )
613+ // Data of a UserDefinedType is actually stored using the data type of its sqlType, so when
614+ // MapObjects is applied to such data it has to work with that sqlType instead.
// Any other input data type is used as-is. Computed lazily on first access.
615+ lazy private val inputDataType = inputData.dataType match {
616+ case u : UserDefinedType [_] => u.sqlType
617+ case _ => inputData.dataType
618+ }
619+
// Applies `lambdaFunction` to each element of `inputCollection` and returns the results as an
// iterator. A single one-field `GenericInternalRow` is allocated up front and overwritten at
// ordinal 0 for every element before the lambda is evaluated against it.
620+ private def executeFuncOnCollection (inputCollection : Seq [_]): Iterator [_] = {
621+ val row = new GenericInternalRow (1 )
// `toIterator.map` is lazy: an element is processed only when the returned iterator is consumed.
622+ inputCollection.toIterator.map { element =>
623+ row.update(0 , element)
624+ lambdaFunction.eval(row)
625+ }
626+ }
627+
// Selects, based on `inputDataType`, the function that views an evaluated input collection as a
// Scala `Seq`. The match runs once, when this lazy val is first accessed.
// NOTE(review): there is no default case, so an `inputDataType` not listed below would raise a
// `scala.MatchError` at that point -- confirm that callers only pass the listed types.
628+ private lazy val convertToSeq : Any => Seq [_] = inputDataType match {
// Already a Scala `Seq`: a plain cast is enough.
629+ case ObjectType (cls) if classOf [Seq [_]].isAssignableFrom(cls) =>
630+ _.asInstanceOf [Seq [_]]
// Array class: cast to `Array` and convert with `toSeq`.
631+ case ObjectType (cls) if cls.isArray =>
632+ _.asInstanceOf [Array [_]].toSeq
// `java.util.List`: converted through `asScala`.
633+ case ObjectType (cls) if classOf [java.util.List [_]].isAssignableFrom(cls) =>
634+ _.asInstanceOf [java.util.List [_]].asScala
// Declared only as `Object`: the concrete kind is decided per value at runtime. Anything that is
// not an array is cast to `Seq` (a `java.util.List` value is not handled in this branch).
635+ case ObjectType (cls) if cls == classOf [Object ] =>
636+ (inputCollection) => {
637+ if (inputCollection.getClass.isArray) {
638+ inputCollection.asInstanceOf [Array [_]].toSeq
639+ } else {
640+ inputCollection.asInstanceOf [Seq [_]]
641+ }
642+ }
// Catalyst array: the value is cast to `ArrayData` and its `array` member is used.
// NOTE(review): the element type `et` is bound but unused, and `array` is presumably not
// available on every `ArrayData` implementation -- confirm.
643+ case ArrayType (et, _) =>
644+ _.asInstanceOf [ArrayData ].array
645+ }
646+
// Selects, based on `customCollectionCls`, the function that runs `executeFuncOnCollection` over
// the input elements and packs the results into the requested kind of collection. The match
// runs once, when this lazy val is first accessed, so the `RuntimeException` of the last case
// is thrown at that point rather than at construction time.
647+ private lazy val mapElements : Seq [_] => Any = customCollectionCls match {
648+ case Some (cls) if classOf [Seq [_]].isAssignableFrom(cls) =>
649+ // Scala sequence
650+ executeFuncOnCollection(_).toSeq
651+ case Some (cls) if classOf [scala.collection.Set [_]].isAssignableFrom(cls) =>
652+ // Scala set
653+ executeFuncOnCollection(_).toSet
654+ case Some (cls) if classOf [java.util.List [_]].isAssignableFrom(cls) =>
655+ // Java list
656+ if (cls == classOf [java.util.List [_]] || cls == classOf [java.util.AbstractList [_]] ||
657+ cls == classOf [java.util.AbstractSequentialList [_]]) {
658+ // A non-concrete `java.util.List` type was requested: these cannot be instantiated, so the
// results are returned as a Scala sequence wrapped with `asJava`.
659+ executeFuncOnCollection(_).toSeq.asJava
660+ } else {
// Look up, reflectively, how to instantiate the requested class: a constructor taking a single
// `int` is preferred, a no-argument constructor is the fallback.
661+ val constructors = cls.getConstructors()
662+ val intParamConstructor = constructors.find { constructor =>
663+ constructor.getParameterCount == 1 && constructor.getParameterTypes()(0 ) == classOf [Int ]
664+ }
665+ val noParamConstructor = constructors.find { constructor =>
666+ constructor.getParameterCount == 0
667+ }
668+
// The `int` constructor receives the input length; the fallback ignores it.
// NOTE(review): `noParamConstructor.get` throws if the class has neither constructor.
669+ val constructor = intParamConstructor.map { intConstructor =>
670+ (len : Int ) => intConstructor.newInstance(len.asInstanceOf [Object ])
671+ }.getOrElse {
672+ (_ : Int ) => noParamConstructor.get.newInstance()
673+ }
674+
675+ // A concrete implementation of `java.util.List` was requested: create a new instance per
// call and add every mapped element to it.
676+ (inputs) => {
677+ val results = executeFuncOnCollection(inputs)
678+ val builder = constructor(inputs.length).asInstanceOf [java.util.List [Any ]]
679+ results.foreach(builder.add(_))
680+ builder
681+ }
682+ }
683+ case None =>
684+ // No custom collection class: the results are returned as a `GenericArrayData`.
685+ x => new GenericArrayData (executeFuncOnCollection(x).toArray)
// Any other requested class is rejected.
686+ case Some (cls) =>
687+ throw new RuntimeException (s " class ` ${cls.getName}` is not supported by `MapObjects` as " +
688+ " resulting collection." )
689+ }
690+
// Interpreted evaluation: evaluates `inputData` against `input`, returns null when that result
// is null, and otherwise converts it with `convertToSeq` and maps it with `mapElements`.
691+ override def eval (input : InternalRow ): Any = {
692+ val inputCollection = inputData.eval(input)
693+
// A null input collection yields a null result without applying the lambda function.
694+ if (inputCollection == null ) {
695+ return null
696+ }
697+ mapElements(convertToSeq(inputCollection))
698+ }
604699
605700 override def dataType : DataType =
606701 customCollectionCls.map(ObjectType .apply).getOrElse(
@@ -647,13 +742,6 @@ case class MapObjects private(
647742 case _ => " "
648743 }
649744
650- // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
651- // When we want to apply MapObjects on it, we have to use it.
652- val inputDataType = inputData.dataType match {
653- case p : PythonUserDefinedType => p.sqlType
654- case _ => inputData.dataType
655- }
656-
657745 // `MapObjects` generates a while loop to traverse the elements of the input collection. We
658746 // need to take care of Seq and List because they may have O(n) complexity for indexed accessing
659747 // like `list.get(1)`. Here we use Iterator to traverse Seq and List.
0 commit comments