@@ -430,24 +430,33 @@ object MapObjects {
430430 * @param function The function applied on the collection elements.
431431 * @param inputData An expression that when evaluated returns a collection object.
432432 * @param elementType The data type of elements in the collection.
433+ * @param collClass The class of the resulting collection; an array class (the default) yields ArrayType
433434 */
434435 def apply (
435436 function : Expression => Expression ,
436437 inputData : Expression ,
437- elementType : DataType ): MapObjects = {
438- val loopValue = " MapObjects_loopValue" + curId.getAndIncrement()
439- val loopIsNull = " MapObjects_loopIsNull" + curId.getAndIncrement()
438+ elementType : DataType ,
439+ collClass : Class [_] = classOf [Array [_]]): MapObjects = {
440+ val id = curId.getAndIncrement()
441+ val loopValue = s " MapObjects_loopValue $id"
442+ val loopIsNull = s " MapObjects_loopIsNull $id"
440443 val loopVar = LambdaVariable (loopValue, loopIsNull, elementType)
441- MapObjects (loopValue, loopIsNull, elementType, function(loopVar), inputData)
444+ val builderValue = s " MapObjects_builderValue $id"
445+ MapObjects (loopValue, loopIsNull, elementType, function(loopVar), inputData,
446+ collClass, builderValue)
442447 }
443448}
444449
445450/**
446451 * Applies the given expression to every element of a collection of items, returning the result
447- * as an ArrayType. This is similar to a typical map operation, but where the lambda function
448- * is expressed using catalyst expressions.
452+ * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
453+ * function is expressed using catalyst expressions.
454+ *
455+ * The type of the result is determined as follows:
456+ * - ArrayType - when collClass is an array class
457+ * - ObjectType(collClass) - when collClass is a collection class
449458 *
450- * The following collection ObjectTypes are currently supported:
459+ * The following collection ObjectTypes are currently supported on input :
451460 * Seq, Array, ArrayData, java.util.List
452461 *
453462 * @param loopValue the name of the loop variable that is used when iterating over the collection, and used
@@ -459,13 +468,18 @@ object MapObjects {
459468 * @param lambdaFunction A function that takes the `loopVar` as input, and is used as the lambda function
460469 * to handle collection elements.
461470 * @param inputData An expression that when evaluated returns a collection object.
471+ * @param collClass The class of the resulting collection
472+ * @param builderValue The name of the builder variable used to construct the resulting collection
473+ * (used only when returning ObjectType)
462474 */
463475case class MapObjects private (
464476 loopValue : String ,
465477 loopIsNull : String ,
466478 loopVarDataType : DataType ,
467479 lambdaFunction : Expression ,
468- inputData : Expression ) extends Expression with NonSQLExpression {
480+ inputData : Expression ,
481+ collClass : Class [_],
482+ builderValue : String ) extends Expression with NonSQLExpression {
469483
470484 override def nullable : Boolean = inputData.nullable
471485
@@ -475,7 +489,8 @@ case class MapObjects private(
475489 throw new UnsupportedOperationException (" Only code-generated evaluation is supported" )
476490
477491 override def dataType : DataType =
478- ArrayType (lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
492+ if (! collClass.isArray) ObjectType (collClass)
493+ else ArrayType (lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
479494
480495 override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
481496 val elementJavaType = ctx.javaType(loopVarDataType)
@@ -558,169 +573,23 @@ case class MapObjects private(
558573 case _ => s " $loopIsNull = $loopValue == null; "
559574 }
560575
561- val code = s """
562- ${genInputData.code}
563- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
564-
565- if (! ${genInputData.isNull}) {
566- $determineCollectionType
567- $convertedType[] $convertedArray = null;
568- int $dataLength = $getLength;
569- $convertedArray = $arrayConstructor;
570-
571- int $loopIndex = 0;
572- while ( $loopIndex < $dataLength) {
573- $loopValue = ( $elementJavaType) ( $getLoopVar);
574- $loopNullCheck
575-
576- ${genFunction.code}
577- if ( ${genFunction.isNull}) {
578- $convertedArray[ $loopIndex] = null;
579- } else {
580- $convertedArray[ $loopIndex] = $genFunctionValue;
581- }
582-
583- $loopIndex += 1;
584- }
585-
586- ${ev.value} = new ${classOf [GenericArrayData ].getName}( $convertedArray);
576+ val (genInit, genAssign, genResult): (String , String => String , String ) =
577+ if (collClass.isArray) {
578+ // array
579+ (s """ $convertedType[] $convertedArray = null;
580+ $convertedArray = $arrayConstructor; """ ,
581+ genValue => s " $convertedArray[ $loopIndex] = $genValue; " ,
582+ s " new ${classOf [GenericArrayData ].getName}( $convertedArray); " )
583+ } else {
584+ // collection
585+ val collObjectName = s " ${collClass.getName}$$ .MODULE $$ "
586+ val getBuilderVar = s " $collObjectName.newBuilder() "
587+
588+ (s """ ${classOf [Builder [_, _]].getName} $builderValue = $getBuilderVar;
589+ $builderValue.sizeHint( $dataLength); """ ,
590+ genValue => s " $builderValue. $$ plus $$ eq( $genValue); " ,
591+ s " ( ${collClass.getName}) $builderValue.result(); " )
587592 }
588- """
589- ev.copy(code = code, isNull = genInputData.isNull)
590- }
591- }
592-
593- object CollectObjects {
594- private val curId = new java.util.concurrent.atomic.AtomicInteger ()
595-
596- /**
597- * Construct an instance of CollectObjects case class.
598- *
599- * @param function The function applied on the collection elements.
600- * @param inputData An expression that when evaluated returns a collection object.
601- * @param elementType The data type of elements in the collection.
602- * @param collClass The type of the resulting collection.
603- */
604- def apply (
605- function : Expression => Expression ,
606- inputData : Expression ,
607- elementType : DataType ,
608- collClass : Class [_]): CollectObjects = {
609- val loopValue = " CollectObjects_loopValue" + curId.getAndIncrement()
610- val loopIsNull = " CollectObjects_loopIsNull" + curId.getAndIncrement()
611- val loopVar = LambdaVariable (loopValue, loopIsNull, elementType)
612- val builderValue = " CollectObjects_builderValue" + curId.getAndIncrement()
613- CollectObjects (loopValue, loopIsNull, elementType, function(loopVar), inputData,
614- collClass, builderValue)
615- }
616- }
617-
618- /**
619- * An equivalent to the [[MapObjects ]] case class but returning an ObjectType containing
620- * a Scala collection constructed using the associated builder, obtained by calling `newBuilder`
621- * on the collection's companion object.
622- *
623- * @param loopValue the name of the loop variable that used when iterate the collection, and used
624- * as input for the `lambdaFunction`
625- * @param loopIsNull the nullity of the loop variable that used when iterate the collection, and
626- * used as input for the `lambdaFunction`
627- * @param loopVarDataType the data type of the loop variable that used when iterate the collection,
628- * and used as input for the `lambdaFunction`
629- * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
630- * to handle collection elements.
631- * @param inputData An expression that when evaluated returns a collection object.
632- * @param collClass The type of the resulting collection.
633- * @param builderValue The name of the builder variable used to construct the resulting collection.
634- */
635- case class CollectObjects private (
636- loopValue : String ,
637- loopIsNull : String ,
638- loopVarDataType : DataType ,
639- lambdaFunction : Expression ,
640- inputData : Expression ,
641- collClass : Class [_],
642- builderValue : String ) extends Expression with NonSQLExpression {
643-
644- override def nullable : Boolean = inputData.nullable
645-
646- override def children : Seq [Expression ] = lambdaFunction :: inputData :: Nil
647-
648- override def eval (input : InternalRow ): Any =
649- throw new UnsupportedOperationException (" Only code-generated evaluation is supported" )
650-
651- override def dataType : DataType = ObjectType (collClass)
652-
653- override def doGenCode (ctx : CodegenContext , ev : ExprCode ): ExprCode = {
654- val collObjectName = s " ${collClass.getName}$$ .MODULE $$ "
655- val getBuilderVar = s " $collObjectName.newBuilder() "
656- val elementJavaType = ctx.javaType(loopVarDataType)
657- ctx.addMutableState(" boolean" , loopIsNull, " " )
658- ctx.addMutableState(elementJavaType, loopValue, " " )
659- val genInputData = inputData.genCode(ctx)
660- val genFunction = lambdaFunction.genCode(ctx)
661- val dataLength = ctx.freshName(" dataLength" )
662- val convertedArray = ctx.freshName(" convertedArray" )
663- val loopIndex = ctx.freshName(" loopIndex" )
664-
665- val convertedType = ctx.boxedType(lambdaFunction.dataType)
666-
667- // In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
668- // of input collection at runtime for this case.
669- val seq = ctx.freshName(" seq" )
670- val array = ctx.freshName(" array" )
671- val determineCollectionType = inputData.dataType match {
672- case ObjectType (cls) if cls == classOf [Object ] =>
673- val seqClass = classOf [Seq [_]].getName
674- s """
675- $seqClass $seq = null;
676- $elementJavaType[] $array = null;
677- if ( ${genInputData.value}.getClass().isArray()) {
678- $array = ( $elementJavaType[]) ${genInputData.value};
679- } else {
680- $seq = ( $seqClass) ${genInputData.value};
681- }
682- """
683- case _ => " "
684- }
685-
686- // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
687- // When we want to apply MapObjects on it, we have to use it.
688- val inputDataType = inputData.dataType match {
689- case p : PythonUserDefinedType => p.sqlType
690- case _ => inputData.dataType
691- }
692-
693- val (getLength, getLoopVar) = inputDataType match {
694- case ObjectType (cls) if classOf [Seq [_]].isAssignableFrom(cls) =>
695- s " ${genInputData.value}.size() " -> s " ${genInputData.value}.apply( $loopIndex) "
696- case ObjectType (cls) if cls.isArray =>
697- s " ${genInputData.value}.length " -> s " ${genInputData.value}[ $loopIndex] "
698- case ObjectType (cls) if classOf [java.util.List [_]].isAssignableFrom(cls) =>
699- s " ${genInputData.value}.size() " -> s " ${genInputData.value}.get( $loopIndex) "
700- case ArrayType (et, _) =>
701- s " ${genInputData.value}.numElements() " -> ctx.getValue(genInputData.value, et, loopIndex)
702- case ObjectType (cls) if cls == classOf [Object ] =>
703- s " $seq == null ? $array.length : $seq.size() " ->
704- s " $seq == null ? $array[ $loopIndex] : $seq.apply( $loopIndex) "
705- }
706-
707- // Make a copy of the data if it's unsafe-backed
708- def makeCopyIfInstanceOf (clazz : Class [_ <: Any ], value : String ) =
709- s " $value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
710- val genFunctionValue = lambdaFunction.dataType match {
711- case StructType (_) => makeCopyIfInstanceOf(classOf [UnsafeRow ], genFunction.value)
712- case ArrayType (_, _) => makeCopyIfInstanceOf(classOf [UnsafeArrayData ], genFunction.value)
713- case MapType (_, _, _) => makeCopyIfInstanceOf(classOf [UnsafeMapData ], genFunction.value)
714- case _ => genFunction.value
715- }
716-
717- val loopNullCheck = inputDataType match {
718- case _ : ArrayType => s " $loopIsNull = ${genInputData.value}.isNullAt( $loopIndex); "
719- // The element of primitive array will never be null.
720- case ObjectType (cls) if cls.isArray && cls.getComponentType.isPrimitive =>
721- s " $loopIsNull = false "
722- case _ => s " $loopIsNull = $loopValue == null; "
723- }
724593
725594 val code = s """
726595 ${genInputData.code}
@@ -729,8 +598,7 @@ case class CollectObjects private(
729598 if (! ${genInputData.isNull}) {
730599 $determineCollectionType
731600 int $dataLength = $getLength;
732- ${classOf [Builder [_, _]].getName} $builderValue = $getBuilderVar;
733- $builderValue.sizeHint( $dataLength);
601+ $genInit
734602
735603 int $loopIndex = 0;
736604 while ( $loopIndex < $dataLength) {
@@ -739,15 +607,15 @@ case class CollectObjects private(
739607
740608 ${genFunction.code}
741609 if ( ${genFunction.isNull}) {
742- $builderValue . $$ plus $$ eq( null);
610+ ${genAssign( " null" )}
743611 } else {
744- $builderValue . $$ plus $$ eq( $ genFunctionValue);
612+ ${genAssign( genFunctionValue)}
745613 }
746614
747615 $loopIndex += 1;
748616 }
749617
750- ${ev.value} = ( ${collClass.getName} ) $builderValue .result();
618+ ${ev.value} = $genResult
751619 }
752620 """
753621 ev.copy(code = code, isNull = genInputData.isNull)
0 commit comments