Commit 99594b2

[SPARK-13094][SQL] Add encoders for seq/array of primitives
Author: Michael Armbrust <michael@databricks.com>
Closes #11014 from marmbrus/seqEncoders.
(cherry picked from commit 29d9218)
Signed-off-by: Michael Armbrust <michael@databricks.com>

1 parent bd8efba
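
For context, a minimal sketch of the usage these encoders enable, as it might look in a Spark 1.6.1 shell; the `sqlContext` name and the sample data are illustrative, not part of the commit:

```scala
import sqlContext.implicits._   // brings the SQLImplicits members below into implicit scope

// Prior to this commit there was no implicit Encoder for a Seq or Array of primitives,
// so these .toDS() calls could not resolve an encoder.
val seqDS = Seq(Seq(1, 2, 3), Seq(4)).toDS()   // Dataset[Seq[Int]] via newIntSeqEncoder
val arrDS = Seq(Array(1.0, 2.0)).toDS()        // Dataset[Array[Double]] via newDoubleArrayEncoder

seqDS.collect()   // the nested sequences round-trip through the encoder
```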

File tree

3 files changed: +91 −2 lines changed

sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala

Lines changed: 62 additions & 1 deletion
```diff
@@ -40,6 +40,8 @@ abstract class SQLImplicits {
   /** @since 1.6.0 */
   implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
 
+  // Primitives
+
   /** @since 1.6.0 */
   implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder()
 
@@ -57,13 +59,72 @@ abstract class SQLImplicits {
 
   /** @since 1.6.0 */
   implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder()
-  /** @since 1.6.0 */
 
+  /** @since 1.6.0 */
   implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder()
 
   /** @since 1.6.0 */
   implicit def newStringEncoder: Encoder[String] = ExpressionEncoder()
 
+  // Seqs
+
+  /** @since 1.6.1 */
+  implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
+
+  // Arrays
+
+  /** @since 1.6.1 */
+  implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder()
+
+  /** @since 1.6.1 */
+  implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] =
+    ExpressionEncoder()
+
   /**
    * Creates a [[Dataset]] from an RDD.
    * @since 1.6.0
```
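
Each overload above simply delegates to ExpressionEncoder(), which derives the actual serializer and deserializer from the TypeTag of the declared type; the implicits exist to pin down the static type for implicit search. The Product variants additionally cover collections of case classes. A hypothetical sketch (assuming sqlContext.implicits._ is in scope, as in a spark-shell session):

```scala
case class Point(x: Int, y: Int)

// Resolved via newProductSeqEncoder[Point] and newProductArrayEncoder[Point]
val seqOfCaseClasses = Seq(Seq(Point(1, 2), Point(3, 4))).toDS()
val arrOfCaseClasses = Seq(Array(Point(1, 2))).toDS()

seqOfCaseClasses.collect()   // the Point elements decode back as case-class instances
```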

sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala

Lines changed: 22 additions & 0 deletions
```diff
@@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       agged,
       "1", "abc", "3", "xyz", "5", "hello")
   }
+
+  test("Arrays and Lists") {
+    checkAnswer(Seq(Seq(1)).toDS(), Seq(1))
+    checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
+    checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
+    checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
+    checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
+    checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
+    checkAnswer(Seq(Seq(true)).toDS(), Seq(true))
+    checkAnswer(Seq(Seq("test")).toDS(), Seq("test"))
+    checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))
+
+    checkAnswer(Seq(Array(1)).toDS(), Array(1))
+    checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
+    checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
+    checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
+    checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
+    checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
+    checkAnswer(Seq(Array(true)).toDS(), Array(true))
+    checkAnswer(Seq(Array("test")).toDS(), Array("test"))
+    checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
+  }
 }
```
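
Each checkAnswer(ds, expected*) assertion above builds a local single-column Dataset, collects it back through the encoder, and compares the decoded values with the expected ones (as a set; see the QueryTest change below). Conceptually the first assertion is close to this illustrative snippet:

```scala
val decoded = Seq(Seq(1)).toDS().collect()   // one row holding the sequence
assert(decoded.toSeq == Seq(Seq(1)))         // the decoded Seq equals the original element-wise
```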

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 7 additions & 1 deletion
```diff
@@ -94,7 +94,13 @@ abstract class QueryTest extends PlanTest {
            """.stripMargin, e)
     }
 
-    if (decoded != expectedAnswer.toSet) {
+    // Handle the case where the return type is an array
+    val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false)
+    def normalEquality = decoded == expectedAnswer.toSet
+    def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet
+    def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq)
+
+    if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
       val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
       val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
 
```
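
The new branch is needed because JVM arrays compare by reference, so a freshly decoded array is never == to the expected one and the plain set comparison always fails for array results; converting both sides to Seq restores element-wise equality. For example:

```scala
Array(1, 2) == Array(1, 2)                // false: arrays use reference equality
Array(1, 2).toSeq == Array(1, 2).toSeq    // true: Seq comparison is element-wise
Set(Array(1)) == Set(Array(1))            // false, hence the special-casing before the Set check
```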
