Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 22 additions & 73 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,8 @@ package org.apache.spark.sql

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
import org.apache.spark.util.Utils
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
import org.apache.spark.sql.types.StructType

/**
* Used to convert a JVM object of type `T` to and from the internal Spark SQL representation.
Expand All @@ -49,83 +47,34 @@ object Encoders {
def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)

def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = {
tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
.asInstanceOf[ExpressionEncoder[(T1, T2)]]
def tuple[T1, T2](
e1: Encoder[T1],
e2: Encoder[T2]): Encoder[(T1, T2)] = {
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
}

def tuple[T1, T2, T3](
enc1: Encoder[T1],
enc2: Encoder[T2],
enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
.asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
e1: Encoder[T1],
e2: Encoder[T2],
e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
}

def tuple[T1, T2, T3, T4](
enc1: Encoder[T1],
enc2: Encoder[T2],
enc3: Encoder[T3],
enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
tuple(Seq(enc1, enc2, enc3, enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
.asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
e1: Encoder[T1],
e2: Encoder[T2],
e3: Encoder[T3],
e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4))
}

def tuple[T1, T2, T3, T4, T5](
enc1: Encoder[T1],
enc2: Encoder[T2],
enc3: Encoder[T3],
enc4: Encoder[T4],
enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
tuple(Seq(enc1, enc2, enc3, enc4, enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
.asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
}

private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
assert(encoders.length > 1)
// make sure all encoders are resolved, i.e. `Attribute` has been resolved to `BoundReference`.
assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))

val schema = StructType(encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
})

val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

val extractExpressions = encoders.map {
case e if e.flat => e.toRowExpressions.head
case other => CreateStruct(other.toRowExpressions)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
case BoundReference(0, t: ObjectType, _) =>
Invoke(
BoundReference(0, ObjectType(cls), nullable = true),
s"_${index + 1}",
t)
}
}

val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
enc.fromRowExpression.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
enc.fromRowExpression.transformUp {
case BoundReference(ordinal, dt, _) =>
GetInternalRowField(BoundReference(index, enc.schema, nullable = true), ordinal, dt)
}
}
}

val constructExpression =
NewInstance(cls, constructExpressions, propagateNull = false, ObjectType(cls))

new ExpressionEncoder[Any](
schema,
flat = false,
extractExpressions,
constructExpression,
ClassTag(cls))
e1: Encoder[T1],
e2: Encoder[T2],
e3: Encoder[T3],
e4: Encoder[T4],
e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
ExpressionEncoder.tuple(
encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,47 +67,77 @@ object ExpressionEncoder {
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
encoders.foreach(_.assertUnresolved())

val schema =
StructType(
encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
})
val schema = StructType(encoders.zipWithIndex.map {
case (e, i) => StructField(s"_${i + 1}", if (e.flat) e.schema.head.dataType else e.schema)
})

val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")

// Rebind the encoders to the nested schema.
val newConstructExpressions = encoders.zipWithIndex.map {
case (e, i) if !e.flat => e.nested(i).fromRowExpression
case (e, i) => e.shift(i).fromRowExpression
val toRowExpressions = encoders.map {
case e if e.flat => e.toRowExpressions.head
case other => CreateStruct(other.toRowExpressions)
}.zipWithIndex.map { case (expr, index) =>
expr.transformUp {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not right. `ExpressionEncoder.shift` updates the ordinal of each `BoundReference` by a delta, but here `i` is the ordinal we want to set, not a delta.

case BoundReference(0, t, _) =>
Invoke(
BoundReference(0, ObjectType(cls), nullable = true),
s"_${index + 1}",
t)
}
}

val constructExpression =
NewInstance(cls, newConstructExpressions, false, ObjectType(cls))

val input = BoundReference(0, ObjectType(cls), false)
val extractExpressions = encoders.zipWithIndex.map {
case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ transformUp {
case b: BoundReference =>
Invoke(input, s"_${i + 1}", b.dataType, Nil)
}))
case (e, i) => e.toRowExpressions.head transformUp {
case b: BoundReference =>
Invoke(input, s"_${i + 1}", b.dataType, Nil)
val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
if (enc.flat) {
enc.fromRowExpression.transform {
case b: BoundReference => b.copy(ordinal = index)
}
} else {
val input = BoundReference(index, enc.schema, nullable = true)
enc.fromRowExpression.transformUp {
case UnresolvedAttribute(nameParts) =>
assert(nameParts.length == 1)
UnresolvedExtractValue(input, Literal(nameParts.head))
case BoundReference(ordinal, dt, _) => GetInternalRowField(input, ordinal, dt)
}
}
}

val fromRowExpression =
NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))

new ExpressionEncoder[Any](
schema,
false,
extractExpressions,
constructExpression,
ClassTag.apply(cls))
flat = false,
toRowExpressions,
fromRowExpression,
ClassTag(cls))
}

/** A helper for producing encoders of Tuple2 from other encoders. */
def tuple[T1, T2](
e1: ExpressionEncoder[T1],
e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]

/** A helper for producing encoders of Tuple3 from other encoders. */
def tuple[T1, T2, T3](
    e1: ExpressionEncoder[T1],
    e2: ExpressionEncoder[T2],
    e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] = {
  val components = Seq[ExpressionEncoder[_]](e1, e2, e3)
  // The Seq-based overload returns ExpressionEncoder[_]; cast back to the statically known type.
  tuple(components).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
}

/** A helper for producing encoders of Tuple4 from other encoders. */
def tuple[T1, T2, T3, T4](
    e1: ExpressionEncoder[T1],
    e2: ExpressionEncoder[T2],
    e3: ExpressionEncoder[T3],
    e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] = {
  val components = Seq[ExpressionEncoder[_]](e1, e2, e3, e4)
  // The Seq-based overload returns ExpressionEncoder[_]; cast back to the statically known type.
  tuple(components).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
}

/** A helper for producing encoders of Tuple5 from other encoders. */
def tuple[T1, T2, T3, T4, T5](
    e1: ExpressionEncoder[T1],
    e2: ExpressionEncoder[T2],
    e3: ExpressionEncoder[T3],
    e4: ExpressionEncoder[T4],
    e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] = {
  val components = Seq[ExpressionEncoder[_]](e1, e2, e3, e4, e5)
  // The Seq-based overload returns ExpressionEncoder[_]; cast back to the statically known type.
  tuple(components).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
}
}

/**
Expand Down Expand Up @@ -208,26 +238,6 @@ case class ExpressionEncoder[T](
})
}

/**
 * Produces a copy of this encoder whose `fromRowExpression` reads its inputs out of the struct
 * stored at ordinal `i` of the input row, rather than from the top-level fields.
 *
 * @param i ordinal of the struct column that holds this encoder's fields
 */
private def nested(i: Int): ExpressionEncoder[T] = {
  // The input type may still be unresolved at this point, so NullType is used as a placeholder
  // for the struct reference at position `i`.
  val input = BoundReference(i, NullType, nullable = true)

  val rewritten = fromRowExpression.transformUp {
    // Attributes become a by-name extraction from the struct.
    case u: Attribute =>
      UnresolvedExtractValue(input, Literal(u.name))
    // Bound references become a by-ordinal field access on the struct.
    case b: BoundReference =>
      val field = StructField(s"i[${b.ordinal}]", b.dataType)
      GetStructField(input, field, b.ordinal)
  }

  copy(fromRowExpression = rewritten)
}

protected val attrs = toRowExpressions.flatMap(_.collect {
case _: UnresolvedAttribute => ""
case a: Attribute => s"#${a.exprId}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,35 @@ class ProductEncoderSuite extends ExpressionEncoderSuite {
productTest(("Seq[Seq[(Int, Int)]]",
Seq(Seq((1, 2)))))

// Round-trip tests for encoders composed with ExpressionEncoder.tuple.
// Each label names the component encoders in the order they are passed to `tuple`.

// Both components are flat encoders.
encodeDecodeTest(
  1 -> 10L,
  ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]),
  "tuple with 2 flat encoders")

// Both components are product encoders.
encodeDecodeTest(
  (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
  ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], ProductEncoder[(Int, Long)]),
  "tuple with 2 product encoders")

// Product encoder first, flat encoder second.
encodeDecodeTest(
  (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
  ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]),
  "tuple with product encoder and flat encoder")

// Flat encoder first, product encoder second.
encodeDecodeTest(
  (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
  ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]),
  "tuple with flat encoder and product encoder")

// A tuple encoder used as a component of another tuple encoder.
encodeDecodeTest(
  (1, (10, 100L)),
  {
    val intEnc = FlatEncoder[Int]
    val longEnc = FlatEncoder[Long]
    ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
  },
  "nested tuple encoder")

/**
 * Registers an encode/decode round-trip test for `input` using `ProductEncoder[T]`,
 * labelled with the simple class name of the input value.
 */
private def productTest[T <: Product : TypeTag](input: T): Unit = {
  val testName = input.getClass.getSimpleName
  encodeDecodeTest(input, ProductEncoder[T], testName)
}
Expand Down