diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 2c8e81ef17d7..592520c59a76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -89,16 +89,11 @@ object ExpressionEncoder {
    */
   def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
     // TODO: check if encoders length is more than 22 and throw exception for it.
-    encoders.foreach(_.assertUnresolved())
-
-    val schema = StructType(encoders.zipWithIndex.map {
-      case (e, i) =>
-        StructField(s"_${i + 1}", e.objSerializer.dataType, e.objSerializer.nullable)
-    })
-
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

+    val newSerializerInput = BoundReference(0, ObjectType(cls), nullable = true)
     val serializers = encoders.zipWithIndex.map { case (enc, index) =>
       val boundRefs = enc.objSerializer.collect { case b: BoundReference => b }.distinct
       assert(boundRefs.size == 1, "object serializer should have only one bound reference but " +
@@ -106,42 +101,39 @@ object ExpressionEncoder {

       val originalInputObject = boundRefs.head
       val newInputObject = Invoke(
-        BoundReference(0, ObjectType(cls), nullable = true),
+        newSerializerInput,
         s"_${index + 1}",
         originalInputObject.dataType,
         returnNullable = originalInputObject.nullable)

       val newSerializer = enc.objSerializer.transformUp {
-        case b: BoundReference => newInputObject
+        case BoundReference(0, _, _) => newInputObject
       }

       Alias(newSerializer, s"_${index + 1}")()
     }
+    val newSerializer = CreateStruct(serializers)

-    val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
+    val newDeserializerInput = GetColumnByOrdinal(0, newSerializer.dataType)
+    val deserializers = encoders.zipWithIndex.map { case (enc, index) =>
       val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
       assert(getColExprs.size == 1, "object deserializer should have only one " +
         s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")

-      val input = GetStructField(GetColumnByOrdinal(0, schema), index)
-      val newDeserializer = enc.objDeserializer.transformUp {
+      val input = GetStructField(newDeserializerInput, index)
+      enc.objDeserializer.transformUp {
         case GetColumnByOrdinal(0, _) => input
       }
-      if (schema(index).nullable) {
-        If(IsNull(input), Literal.create(null, newDeserializer.dataType), newDeserializer)
-      } else {
-        newDeserializer
-      }
     }
+    val newDeserializer = NewInstance(cls, deserializers, ObjectType(cls), propagateNull = false)

-    val serializer = If(IsNull(BoundReference(0, ObjectType(cls), nullable = true)),
-      Literal.create(null, schema), CreateStruct(serializers))
-    val deserializer =
-      NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
+    def nullSafe(input: Expression, result: Expression): Expression = {
+      If(IsNull(input), Literal.create(null, result.dataType), result)
+    }

     new ExpressionEncoder[Any](
-      serializer,
-      deserializer,
+      nullSafe(newSerializerInput, newSerializer),
+      nullSafe(newDeserializerInput, newDeserializer),
       ClassTag(cls))
   }
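The gist of the change: the old code wrapped each child deserializer in its own `If(IsNull(...))` based on `schema(index).nullable`, while the new code rewires every child serializer/deserializer to a shared input (`newSerializerInput` / `newDeserializerInput`) and applies a single `nullSafe` wrapper to the whole serializer and the whole deserializer. The sketch below is illustrative only: `SimpleEncoder`, `tuple2`, and the function-level `nullSafe` are hypothetical stand-ins, not Spark APIs; it just mirrors the shape of the patched `ExpressionEncoder.tuple` with plain Scala functions instead of Catalyst expressions.

```scala
// Standalone sketch (not Spark code): models the null-safe tuple composition
// from the patch with plain Scala functions instead of Catalyst expressions.
object TupleEncoderSketch {

  // Hypothetical stand-in for ExpressionEncoder: one serializer and one
  // deserializer per type, each reading from a single input value.
  final case class SimpleEncoder[T](serialize: T => Any, deserialize: Any => T)

  // Single top-level null check, mirroring the patch's `nullSafe(input, result)`:
  // if the input is null, short-circuit to null instead of evaluating the result.
  private def nullSafe[T >: Null](input: Any)(result: => T): T =
    if (input == null) null else result

  // Compose two encoders into a Tuple2 encoder. Each child now reads its own
  // field (_1 / _2) of the shared tuple input, and null handling happens once
  // at the top level rather than per field, as in the patched tuple().
  def tuple2[A, B](encA: SimpleEncoder[A], encB: SimpleEncoder[B]): SimpleEncoder[(A, B)] =
    SimpleEncoder[(A, B)](
      serialize = pair => nullSafe[Array[Any]](pair) {
        Array[Any](encA.serialize(pair._1), encB.serialize(pair._2))
      },
      deserialize = raw => nullSafe[(A, B)](raw) {
        val fields = raw.asInstanceOf[Array[Any]]
        (encA.deserialize(fields(0)), encB.deserialize(fields(1)))
      })

  def main(args: Array[String]): Unit = {
    val intEnc = SimpleEncoder[Int](identity, _.asInstanceOf[Int])
    val strEnc = SimpleEncoder[String](identity, _.asInstanceOf[String])
    val pairEnc = tuple2(intEnc, strEnc)

    val roundTripped = pairEnc.deserialize(pairEnc.serialize((1, "a")))
    println(roundTripped)            // (1,a)
    println(pairEnc.serialize(null)) // null: the single top-level check avoids an NPE
  }
}
```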