diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 6d94764f1bfa..bfe8ae6abc1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -406,7 +406,7 @@ case class WrapOption(child: Expression, optType: DataType) } /** - * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed + * A placeholder for the loop variable used in [[MapObjects]]. This should never be constructed * manually, but will instead be passed into the provided lambda function. */ case class LambdaVariable( @@ -421,6 +421,27 @@ case class LambdaVariable( } } +/** + * When constructing [[MapObjects]], the element type must be given, which may not be available + * before analysis. This class acts like a placeholder for [[MapObjects]], and will be replaced by + * [[MapObjects]] during analysis after the input data is resolved. + * Note that, ideally we should not serialize and send unresolved expressions to executors, but + * users may accidentally do this(e.g. mistakenly reference an encoder instance when implementing + * Aggregator). Here we mark `function` as transient because it may reference scala Type, which is + * not serializable. Then even users mistakenly reference unresolved expression and serialize it, + * it's just a performance issue(more network traffic), and will not fail. + */ +case class UnresolvedMapObjects( + @transient function: Expression => Expression, + child: Expression, + customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { + override lazy val resolved = false + + override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { + throw new UnsupportedOperationException("not resolved") + } +} + object MapObjects { private val curId = new java.util.concurrent.atomic.AtomicInteger() @@ -442,20 +463,8 @@ object MapObjects { val loopValue = s"MapObjects_loopValue$id" val loopIsNull = s"MapObjects_loopIsNull$id" val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) - val builderValue = s"MapObjects_builderValue$id" - MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData, - customCollectionCls, builderValue) - } -} - -case class UnresolvedMapObjects( - function: Expression => Expression, - child: Expression, - customCollectionCls: Option[Class[_]] = None) extends UnaryExpression with Unevaluable { - override lazy val resolved = false - - override def dataType: DataType = customCollectionCls.map(ObjectType.apply).getOrElse { - throw new UnsupportedOperationException("not resolved") + MapObjects( + loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) } } @@ -482,8 +491,6 @@ case class UnresolvedMapObjects( * @param inputData An expression that when evaluated returns a collection object. * @param customCollectionCls Class of the resulting collection (returning ObjectType) * or None (returning ArrayType) - * @param builderValue The name of the builder variable used to construct the resulting collection - * (used only when returning ObjectType) */ case class MapObjects private( loopValue: String, @@ -491,8 +498,7 @@ case class MapObjects private( loopVarDataType: DataType, lambdaFunction: Expression, inputData: Expression, - customCollectionCls: Option[Class[_]], - builderValue: String) extends Expression with NonSQLExpression { + customCollectionCls: Option[Class[_]]) extends Expression with NonSQLExpression { override def nullable: Boolean = inputData.nullable @@ -590,15 +596,15 @@ case class MapObjects private( customCollectionCls match { case Some(cls) => // collection - val collObjectName = s"${cls.getName}$$.MODULE$$" - val getBuilderVar = s"$collObjectName.newBuilder()" + val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" + val builder = ctx.freshName("collectionBuilder") ( s""" - ${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar; - $builderValue.sizeHint($dataLength); + ${classOf[Builder[_, _]].getName} $builder = $getBuilder; + $builder.sizeHint($dataLength); """, - genValue => s"$builderValue.$$plus$$eq($genValue);", - s"(${cls.getName}) $builderValue.result();" + genValue => s"$builder.$$plus$$eq($genValue);", + s"(${cls.getName}) $builder.result();" ) case None => // array