@@ -160,18 +160,29 @@ object SamplePushDown extends Rule[LogicalPlan] {
 * representation of data item. For example back to back map operations.
 */
object EliminateSerialization extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
  private def eliminateSerialization(p: LogicalPlan, isTop: Boolean): LogicalPlan = p transform {
    case d @ DeserializeToObject(_, _, s: SerializeFromObject)
        if d.outputObjectType == s.inputObjectType =>
      // A workaround for SPARK-14803. Remove this after it is fixed.
      if (d.outputObjectType.isInstanceOf[ObjectType] &&
          d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
        s.child
      } else if (isTop) {
        // If DeserializeToObject is at the top of logical plan, we don't need to preserve output
        // expr id.
        s.child
      } else {
        // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
        val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
        Project(objAttr :: Nil, s.child)
      }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case p if p.children.exists(_.isInstanceOf[DeserializeToObject]) =>
      p.withNewChildren(p.children.map(eliminateSerialization(_, false)))
    case d: DeserializeToObject =>
      eliminateSerialization(d, true)
    case a @ AppendColumns(_, _, _, s: SerializeFromObject)
        if a.deserializer.dataType == s.inputObjectType =>
      AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
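
The comments above carry the core reasoning of the change: adjacent SerializeFromObject/DeserializeToObject pairs (for example from back-to-back typed map operations) can be collapsed, but when the deserialize node is not at the top of the plan its output expr id has to survive through an aliasing Project, while a top-level DeserializeToObject, such as the one Dataset.rdd introduces, has no downstream consumer of that attribute and can simply be replaced by the serializer's child. Below is a minimal standalone sketch of where such plans come from; it is not part of this patch, assumes a Spark 2.x build where SparkSession is available, and the object and app names are illustrative:

import org.apache.spark.sql.SparkSession

// Mirrors the generic case class added to DatasetSuite in this patch.
case class Generic[T](id: T, value: Double)

object EliminateSerializationDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("EliminateSerializationDemo")
      .getOrCreate()
    import spark.implicits._

    val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS()

    // Back-to-back typed maps: the second map's deserializer sits directly on
    // top of the first map's serializer, so the pair is a candidate for
    // elimination while the parent still refers to the deserializer's output.
    val ds2 = ds.map(g => Generic(g.id, g.value)).map(g => g.value + 1)
    println(ds2.queryExecution.optimizedPlan)

    // Dataset.rdd puts a DeserializeToObject at the very top of the plan; no
    // operator above it references its output attribute.
    println(ds.rdd.map(_.id).count())

    spark.stop()
  }
}

Printing queryExecution.optimizedPlan is only a convenient way to eyeball where the serialize/deserialize pair ends up; the tests below are what actually pin down the intended behaviour.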
@@ -42,7 +42,7 @@ class EliminateSerializationSuite extends PlanTest {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('obj.as("obj")).analyze
    val expected = input
    comparePlans(optimized, expected)
  }

@@ -659,6 +659,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
    checkDataset(DatasetTransform.addOne(dataset), 2, 3, 4)
  }

  test("dataset.rdd with generic case class") {
    val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS
    val ds2 = ds.map(g => Generic(g.id, g.value))
    ds.rdd.map(r => r.id).count
    ds2.rdd.map(r => r.id).count
  }

  test("runtime null check for RowEncoder") {
    val schema = new StructType().add("i", IntegerType, nullable = false)
    val df = sqlContext.range(10).map(l => {
@@ -676,6 +683,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
  }
}

case class Generic[T](id: T, value: Double)

case class OtherTuple(_1: String, _2: Int)

case class TupleClass(data: (Int, String))
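
The new "dataset.rdd with generic case class" test above only forces evaluation through count, which is enough to exercise the top-of-plan DeserializeToObject path for a generic case class. A hedged variant that also asserts the round-tripped values could sit alongside it in DatasetSuite, reusing Generic and the suite's implicit encoders; the test name and assertions below are illustrative, not part of this patch:

  test("dataset.rdd with generic case class preserves values") {
    val ds = Seq(Generic(1, 1.0), Generic(2, 2.0)).toDS
    val ds2 = ds.map(g => Generic(g.id, g.value))
    // Go through the RDD conversion and compare the round-tripped fields.
    assert(ds.rdd.map(_.id).collect().sorted.toSeq == Seq(1, 2))
    assert(ds2.rdd.map(_.value).collect().sorted.toSeq == Seq(1.0, 2.0))
  }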