diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a3ab89dc7114..7001fa12cc7a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -160,18 +160,29 @@ object SamplePushDown extends Rule[LogicalPlan] {
  * representation of data item. For example back to back map operations.
  */
 object EliminateSerialization extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  private def eliminateSerialization(p: LogicalPlan, isTop: Boolean): LogicalPlan = p transform {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
         if d.outputObjectType == s.inputObjectType =>
       // A workaround for SPARK-14803. Remove this after it is fixed.
       if (d.outputObjectType.isInstanceOf[ObjectType] &&
           d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
         s.child
+      } else if (isTop) {
+        // If DeserializeToObject is at the top of logical plan, we don't need to preserve output
+        // expr id.
+        s.child
       } else {
         // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
         val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
         Project(objAttr :: Nil, s.child)
       }
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case p if p.children.exists(_.isInstanceOf[DeserializeToObject]) =>
+      p.withNewChildren(p.children.map(eliminateSerialization(_, false)))
+    case d: DeserializeToObject =>
+      eliminateSerialization(d, true)
     case a @ AppendColumns(_, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjectType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
index 3c033ddc374c..79835edf3f85 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
@@ -42,7 +42,7 @@ class EliminateSerializationSuite extends PlanTest {
     val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
     val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
     val optimized = Optimize.execute(plan)
-    val expected = input.select('obj.as("obj")).analyze
+    val expected = input
     comparePlans(optimized, expected)
   }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3cb4e52c6d41..0ae23e282b02 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -659,6 +659,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
   }

+  test("dataset.rdd with generic case class") {
+    val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS
+    val ds2 = ds.map(g => Generic(g.id, g.value))
+    ds.rdd.map(r => r.id).count
+    ds2.rdd.map(r => r.id).count
+  }
+
   test("runtime null check for RowEncoder") {
     val schema = new StructType().add("i", IntegerType, nullable = false)
     val df = sqlContext.range(10).map(l => {
@@ -676,6 +683,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   }
 }

+case class Generic[T](id: T, value: Double)
+
 case class OtherTuple(_1: String, _2: Int)

 case class TupleClass(data: (Int, String))
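
For context, the scenario behind the new DatasetSuite test can be sketched as a standalone program. Calling .rdd on a Dataset puts a DeserializeToObject operator at the top of the logical plan; with this change, EliminateSerialization collapses the adjacent SerializeFromObject/DeserializeToObject pair there directly, instead of inserting a Project to preserve the output expression id (that is only needed when a parent operator still refers to it). The sketch below is illustrative only and not part of the patch: the SparkSession setup and the object name RddOnGenericDemo are assumptions for a self-contained Spark 2.x example.

    import org.apache.spark.sql.SparkSession

    // Mirrors the generic case class added to DatasetSuite.
    case class Generic[T](id: T, value: Double)

    object RddOnGenericDemo {
      def main(args: Array[String]): Unit = {
        // Assumed local session setup; not part of the patch.
        val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
        import spark.implicits._

        val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS()
        val ds2 = ds.map(g => Generic(g.id, g.value))

        // `.rdd` makes DeserializeToObject the root of the plan, so the optimizer can now
        // drop the serialize/deserialize pair without adding a renaming Project on top.
        println(ds.rdd.map(_.id).count())
        println(ds2.rdd.map(_.id).count())

        spark.stop()
      }
    }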