Skip to content

Commit b5f87bd

Browse files
committed
Changes based on code review
Merge CollectObjects with MapObjects Remove SequenceBenchmark
1 parent 85edddd commit b5f87bd

File tree

3 files changed

+45
-251
lines changed

3 files changed

+45
-251
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ object ScalaReflection extends ScalaReflection {
311311
case NoSymbol => classOf[Seq[_]]
312312
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
313313
}
314-
CollectObjects(mapFunction, getPath, dataType, cls)
314+
MapObjects(mapFunction, getPath, dataType, cls)
315315

316316
case t if t <:< localTypeOf[Map[_, _]] =>
317317
// TODO: add walked type path for map

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 44 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -430,24 +430,33 @@ object MapObjects {
430430
* @param function The function applied on the collection elements.
431431
* @param inputData An expression that when evaluated returns a collection object.
432432
* @param elementType The data type of elements in the collection.
433+
* @param collClass The class of the resulting collection
433434
*/
434435
def apply(
435436
function: Expression => Expression,
436437
inputData: Expression,
437-
elementType: DataType): MapObjects = {
438-
val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
439-
val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
438+
elementType: DataType,
439+
collClass: Class[_] = classOf[Array[_]]): MapObjects = {
440+
val id = curId.getAndIncrement()
441+
val loopValue = s"MapObjects_loopValue$id"
442+
val loopIsNull = s"MapObjects_loopIsNull$id"
440443
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
441-
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
444+
val builderValue = s"MapObjects_builderValue$id"
445+
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
446+
collClass, builderValue)
442447
}
443448
}
444449

445450
/**
446451
* Applies the given expression to every element of a collection of items, returning the result
447-
* as an ArrayType. This is similar to a typical map operation, but where the lambda function
448-
* is expressed using catalyst expressions.
452+
* as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
453+
* function is expressed using catalyst expressions.
454+
*
455+
* The type of the result is determined as follows:
456+
* - ArrayType - when collClass is an array class
457+
* - ObjectType(collClass) - when collClass is a collection class
449458
*
450-
* The following collection ObjectTypes are currently supported:
459+
* The following collection ObjectTypes are currently supported on input:
451460
* Seq, Array, ArrayData, java.util.List
452461
*
453462
* @param loopValue the name of the loop variable that used when iterate the collection, and used
@@ -459,13 +468,18 @@ object MapObjects {
459468
* @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
460469
* to handle collection elements.
461470
* @param inputData An expression that when evaluated returns a collection object.
471+
* @param collClass The class of the resulting collection
472+
* @param builderValue The name of the builder variable used to construct the resulting collection
473+
* (used only when returning ObjectType)
462474
*/
463475
case class MapObjects private(
464476
loopValue: String,
465477
loopIsNull: String,
466478
loopVarDataType: DataType,
467479
lambdaFunction: Expression,
468-
inputData: Expression) extends Expression with NonSQLExpression {
480+
inputData: Expression,
481+
collClass: Class[_],
482+
builderValue: String) extends Expression with NonSQLExpression {
469483

470484
override def nullable: Boolean = inputData.nullable
471485

@@ -475,7 +489,8 @@ case class MapObjects private(
475489
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
476490

477491
override def dataType: DataType =
478-
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
492+
if (!collClass.isArray) ObjectType(collClass)
493+
else ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
479494

480495
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
481496
val elementJavaType = ctx.javaType(loopVarDataType)
@@ -558,169 +573,23 @@ case class MapObjects private(
558573
case _ => s"$loopIsNull = $loopValue == null;"
559574
}
560575

561-
val code = s"""
562-
${genInputData.code}
563-
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
564-
565-
if (!${genInputData.isNull}) {
566-
$determineCollectionType
567-
$convertedType[] $convertedArray = null;
568-
int $dataLength = $getLength;
569-
$convertedArray = $arrayConstructor;
570-
571-
int $loopIndex = 0;
572-
while ($loopIndex < $dataLength) {
573-
$loopValue = ($elementJavaType) ($getLoopVar);
574-
$loopNullCheck
575-
576-
${genFunction.code}
577-
if (${genFunction.isNull}) {
578-
$convertedArray[$loopIndex] = null;
579-
} else {
580-
$convertedArray[$loopIndex] = $genFunctionValue;
581-
}
582-
583-
$loopIndex += 1;
584-
}
585-
586-
${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
576+
val (genInit, genAssign, genResult): (String, String => String, String) =
577+
if (collClass.isArray) {
578+
// array
579+
(s"""$convertedType[] $convertedArray = null;
580+
$convertedArray = $arrayConstructor;""",
581+
genValue => s"$convertedArray[$loopIndex] = $genValue;",
582+
s"new ${classOf[GenericArrayData].getName}($convertedArray);")
583+
} else {
584+
// collection
585+
val collObjectName = s"${collClass.getName}$$.MODULE$$"
586+
val getBuilderVar = s"$collObjectName.newBuilder()"
587+
588+
(s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
589+
$builderValue.sizeHint($dataLength);""",
590+
genValue => s"$builderValue.$$plus$$eq($genValue);",
591+
s"(${collClass.getName}) $builderValue.result();")
587592
}
588-
"""
589-
ev.copy(code = code, isNull = genInputData.isNull)
590-
}
591-
}
592-
593-
object CollectObjects {
594-
private val curId = new java.util.concurrent.atomic.AtomicInteger()
595-
596-
/**
597-
* Construct an instance of CollectObjects case class.
598-
*
599-
* @param function The function applied on the collection elements.
600-
* @param inputData An expression that when evaluated returns a collection object.
601-
* @param elementType The data type of elements in the collection.
602-
* @param collClass The type of the resulting collection.
603-
*/
604-
def apply(
605-
function: Expression => Expression,
606-
inputData: Expression,
607-
elementType: DataType,
608-
collClass: Class[_]): CollectObjects = {
609-
val loopValue = "CollectObjects_loopValue" + curId.getAndIncrement()
610-
val loopIsNull = "CollectObjects_loopIsNull" + curId.getAndIncrement()
611-
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
612-
val builderValue = "CollectObjects_builderValue" + curId.getAndIncrement()
613-
CollectObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
614-
collClass, builderValue)
615-
}
616-
}
617-
618-
/**
619-
* An equivalent to the [[MapObjects]] case class but returning an ObjectType containing
620-
* a Scala collection constructed using the associated builder, obtained by calling `newBuilder`
621-
* on the collection's companion object.
622-
*
623-
* @param loopValue the name of the loop variable that used when iterate the collection, and used
624-
* as input for the `lambdaFunction`
625-
* @param loopIsNull the nullity of the loop variable that used when iterate the collection, and
626-
* used as input for the `lambdaFunction`
627-
* @param loopVarDataType the data type of the loop variable that used when iterate the collection,
628-
* and used as input for the `lambdaFunction`
629-
* @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
630-
* to handle collection elements.
631-
* @param inputData An expression that when evaluated returns a collection object.
632-
* @param collClass The type of the resulting collection.
633-
* @param builderValue The name of the builder variable used to construct the resulting collection.
634-
*/
635-
case class CollectObjects private(
636-
loopValue: String,
637-
loopIsNull: String,
638-
loopVarDataType: DataType,
639-
lambdaFunction: Expression,
640-
inputData: Expression,
641-
collClass: Class[_],
642-
builderValue: String) extends Expression with NonSQLExpression {
643-
644-
override def nullable: Boolean = inputData.nullable
645-
646-
override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil
647-
648-
override def eval(input: InternalRow): Any =
649-
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
650-
651-
override def dataType: DataType = ObjectType(collClass)
652-
653-
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
654-
val collObjectName = s"${collClass.getName}$$.MODULE$$"
655-
val getBuilderVar = s"$collObjectName.newBuilder()"
656-
val elementJavaType = ctx.javaType(loopVarDataType)
657-
ctx.addMutableState("boolean", loopIsNull, "")
658-
ctx.addMutableState(elementJavaType, loopValue, "")
659-
val genInputData = inputData.genCode(ctx)
660-
val genFunction = lambdaFunction.genCode(ctx)
661-
val dataLength = ctx.freshName("dataLength")
662-
val convertedArray = ctx.freshName("convertedArray")
663-
val loopIndex = ctx.freshName("loopIndex")
664-
665-
val convertedType = ctx.boxedType(lambdaFunction.dataType)
666-
667-
// In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
668-
// of input collection at runtime for this case.
669-
val seq = ctx.freshName("seq")
670-
val array = ctx.freshName("array")
671-
val determineCollectionType = inputData.dataType match {
672-
case ObjectType(cls) if cls == classOf[Object] =>
673-
val seqClass = classOf[Seq[_]].getName
674-
s"""
675-
$seqClass $seq = null;
676-
$elementJavaType[] $array = null;
677-
if (${genInputData.value}.getClass().isArray()) {
678-
$array = ($elementJavaType[]) ${genInputData.value};
679-
} else {
680-
$seq = ($seqClass) ${genInputData.value};
681-
}
682-
"""
683-
case _ => ""
684-
}
685-
686-
// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
687-
// When we want to apply MapObjects on it, we have to use it.
688-
val inputDataType = inputData.dataType match {
689-
case p: PythonUserDefinedType => p.sqlType
690-
case _ => inputData.dataType
691-
}
692-
693-
val (getLength, getLoopVar) = inputDataType match {
694-
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
695-
s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
696-
case ObjectType(cls) if cls.isArray =>
697-
s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
698-
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
699-
s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
700-
case ArrayType(et, _) =>
701-
s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
702-
case ObjectType(cls) if cls == classOf[Object] =>
703-
s"$seq == null ? $array.length : $seq.size()" ->
704-
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
705-
}
706-
707-
// Make a copy of the data if it's unsafe-backed
708-
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
709-
s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
710-
val genFunctionValue = lambdaFunction.dataType match {
711-
case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
712-
case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
713-
case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
714-
case _ => genFunction.value
715-
}
716-
717-
val loopNullCheck = inputDataType match {
718-
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
719-
// The element of primitive array will never be null.
720-
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
721-
s"$loopIsNull = false"
722-
case _ => s"$loopIsNull = $loopValue == null;"
723-
}
724593

725594
val code = s"""
726595
${genInputData.code}
@@ -729,8 +598,7 @@ case class CollectObjects private(
729598
if (!${genInputData.isNull}) {
730599
$determineCollectionType
731600
int $dataLength = $getLength;
732-
${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
733-
$builderValue.sizeHint($dataLength);
601+
$genInit
734602

735603
int $loopIndex = 0;
736604
while ($loopIndex < $dataLength) {
@@ -739,15 +607,15 @@ case class CollectObjects private(
739607

740608
${genFunction.code}
741609
if (${genFunction.isNull}) {
742-
$builderValue.$$plus$$eq(null);
610+
${genAssign("null")}
743611
} else {
744-
$builderValue.$$plus$$eq($genFunctionValue);
612+
${genAssign(genFunctionValue)}
745613
}
746614

747615
$loopIndex += 1;
748616
}
749617

750-
${ev.value} = (${collClass.getName}) $builderValue.result();
618+
${ev.value} = $genResult
751619
}
752620
"""
753621
ev.copy(code = code, isNull = genInputData.isNull)

sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SequenceBenchmark.scala

Lines changed: 0 additions & 74 deletions
This file was deleted.

0 commit comments

Comments
 (0)