SQLImplicitsTestSuite.scala
@@ -52,6 +52,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll {

test("test implicit encoder resolution") {
val spark = session
+import org.apache.spark.util.ArrayImplicits._
import spark.implicits._
def testImplicit[T: Encoder](expected: T): Unit = {
val encoder = encoderFor[T]
@@ -84,64 +85,74 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll {
testImplicit(booleans)
testImplicit(booleans.toSeq)
testImplicit(booleans.toSeq)(newBooleanSeqEncoder)
+testImplicit(booleans.toImmutableArraySeq)

val bytes = Array(76.toByte, 59.toByte, 121.toByte)
testImplicit(bytes.head)
testImplicit(java.lang.Byte.valueOf(bytes.head))
testImplicit(bytes)
testImplicit(bytes.toSeq)
testImplicit(bytes.toSeq)(newByteSeqEncoder)
+testImplicit(bytes.toImmutableArraySeq)

val shorts = Array(21.toShort, (-213).toShort, 14876.toShort)
testImplicit(shorts.head)
testImplicit(java.lang.Short.valueOf(shorts.head))
testImplicit(shorts)
testImplicit(shorts.toSeq)
testImplicit(shorts.toSeq)(newShortSeqEncoder)
+testImplicit(shorts.toImmutableArraySeq)

val ints = Array(4, 6, 5)
testImplicit(ints.head)
testImplicit(java.lang.Integer.valueOf(ints.head))
testImplicit(ints)
testImplicit(ints.toSeq)
testImplicit(ints.toSeq)(newIntSeqEncoder)
+testImplicit(ints.toImmutableArraySeq)

val longs = Array(System.nanoTime(), System.currentTimeMillis())
testImplicit(longs.head)
testImplicit(java.lang.Long.valueOf(longs.head))
testImplicit(longs)
testImplicit(longs.toSeq)
testImplicit(longs.toSeq)(newLongSeqEncoder)
+testImplicit(longs.toImmutableArraySeq)

val floats = Array(3f, 10.9f)
testImplicit(floats.head)
testImplicit(java.lang.Float.valueOf(floats.head))
testImplicit(floats)
testImplicit(floats.toSeq)
testImplicit(floats.toSeq)(newFloatSeqEncoder)
+testImplicit(floats.toImmutableArraySeq)

val doubles = Array(23.78d, -329.6d)
testImplicit(doubles.head)
testImplicit(java.lang.Double.valueOf(doubles.head))
testImplicit(doubles)
testImplicit(doubles.toSeq)
testImplicit(doubles.toSeq)(newDoubleSeqEncoder)
+testImplicit(doubles.toImmutableArraySeq)

val strings = Array("foo", "baz", "bar")
testImplicit(strings.head)
testImplicit(strings)
testImplicit(strings.toSeq)
testImplicit(strings.toSeq)(newStringSeqEncoder)
+testImplicit(strings.toImmutableArraySeq)

val myTypes = Array(MyType(12L, Math.E, Math.PI), MyType(0, 0, 0))
testImplicit(myTypes.head)
testImplicit(myTypes)
testImplicit(myTypes.toSeq)
testImplicit(myTypes.toSeq)(newProductSeqEncoder[MyType])
+testImplicit(myTypes.toImmutableArraySeq)

// Others.
val decimal = java.math.BigDecimal.valueOf(3141527000000000000L, 18)
testImplicit(decimal)
+testImplicit(Array(decimal).toImmutableArraySeq)
testImplicit(BigDecimal(decimal))
testImplicit(Date.valueOf(LocalDate.now()))
testImplicit(LocalDate.now())
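Note on the test additions above: with `org.apache.spark.util.ArrayImplicits._` in scope, `Array#toImmutableArraySeq` produces an `immutable.ArraySeq`, and the suite now checks that an implicit encoder resolves for it (through the generic `Seq` encoder, since `immutable.ArraySeq` is a `Seq`). A minimal usage sketch of what this enables, assuming a Connect server is reachable at `sc://localhost`; the session variable `spark` is only illustrative:

```scala
import scala.collection.immutable

import org.apache.spark.sql.SparkSession
import org.apache.spark.util.ArrayImplicits._

// Assumed setup: a Spark Connect server listening on localhost.
val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()
import spark.implicits._

// Each element of the Dataset is an immutable.ArraySeq[Int]; the encoder is
// now resolved implicitly, just like for other Seq types.
val data: Seq[immutable.ArraySeq[Int]] = Seq(Array(1, 2, 3).toImmutableArraySeq)
val ds = spark.createDataset(data)
ds.collect().foreach(println)
```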
ArrowDeserializers.scala
@@ -24,6 +24,7 @@ import java.time._
import java.util
import java.util.{List => JList, Locale, Map => JMap}

+import scala.collection.immutable
import scala.collection.mutable
import scala.reflect.ClassTag

@@ -46,6 +47,7 @@ import org.apache.spark.sql.types.Decimal
*/
object ArrowDeserializers {
import ArrowEncoderUtils._
+import org.apache.spark.util.ArrayImplicits._

/**
* Create an Iterator of `T`. This iterator takes an Iterator of Arrow IPC Streams, and
@@ -222,6 +224,13 @@ object ArrowDeserializers {
ScalaCollectionUtils.wrap(array)
}
}
+} else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) {
+new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) {
+def value(i: Int): immutable.ArraySeq[Any] = {
+val array = getArray(vector, i, deserializer)(element.clsTag)
+array.asInstanceOf[Array[_]].toImmutableArraySeq
+}
+}
} else if (isSubClass(Classes.ITERABLE, tag)) {
val companion = ScalaCollectionUtils.getIterableCompanion(tag)
new VectorFieldDeserializer[Iterable[Any], ListVector](v) {
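The new deserializer branch above materializes the list values into an `Array` of the element class and then converts it with `toImmutableArraySeq`. In Spark's `ArrayImplicits`, that conversion is expected to wrap the backing array rather than copy it (via `immutable.ArraySeq.unsafeWrapArray`). A standalone sketch of that wrapping behavior, independent of the deserializer:

```scala
import scala.collection.immutable

// Standalone illustration, not the deserializer itself: unsafeWrapArray wraps
// the existing array, so no element copy happens when building the ArraySeq.
val backing = Array(76.toByte, 59.toByte, 121.toByte)
val wrapped: immutable.ArraySeq[Byte] = immutable.ArraySeq.unsafeWrapArray(backing)
assert(wrapped.unsafeArray eq backing) // same underlying array instance
assert(wrapped == Seq(76.toByte, 59.toByte, 121.toByte))
```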
ArrowEncoderUtils.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.connect.client.arrow

+import scala.collection.immutable
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
@@ -26,6 +27,7 @@ import org.apache.arrow.vector.complex.StructVector
private[arrow] object ArrowEncoderUtils {
object Classes {
val MUTABLE_ARRAY_SEQ: Class[_] = classOf[mutable.ArraySeq[_]]
+val IMMUTABLE_ARRAY_SEQ: Class[_] = classOf[immutable.ArraySeq[_]]
val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]]
val MAP: Class[_] = classOf[scala.collection.Map[_, _]]
val JLIST: Class[_] = classOf[java.util.List[_]]
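The dedicated `IMMUTABLE_ARRAY_SEQ` constant matters for dispatch order: `immutable.ArraySeq` is itself an `Iterable`, so the deserializer has to test the specific class before falling through to the generic `ITERABLE` branch, which is how the new branch is ordered in the `ArrowDeserializers` change above. A small sketch of the subtype relationship, assuming `isSubClass` amounts to a runtime-class assignability check:

```scala
import scala.collection.immutable
import scala.reflect.classTag

// immutable.ArraySeq satisfies both checks; only the branch order keeps it
// from being handled by the generic Iterable path.
val runtimeClass = classTag[immutable.ArraySeq[Int]].runtimeClass
assert(classOf[immutable.ArraySeq[_]].isAssignableFrom(runtimeClass))
assert(classOf[Iterable[_]].isAssignableFrom(runtimeClass))
```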