diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index b2c13850a13a0..3e4704b6ab8e0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -52,6 +52,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { test("test implicit encoder resolution") { val spark = session + import org.apache.spark.util.ArrayImplicits._ import spark.implicits._ def testImplicit[T: Encoder](expected: T): Unit = { val encoder = encoderFor[T] @@ -84,6 +85,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(booleans) testImplicit(booleans.toSeq) testImplicit(booleans.toSeq)(newBooleanSeqEncoder) + testImplicit(booleans.toImmutableArraySeq) val bytes = Array(76.toByte, 59.toByte, 121.toByte) testImplicit(bytes.head) @@ -91,6 +93,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(bytes) testImplicit(bytes.toSeq) testImplicit(bytes.toSeq)(newByteSeqEncoder) + testImplicit(bytes.toImmutableArraySeq) val shorts = Array(21.toShort, (-213).toShort, 14876.toShort) testImplicit(shorts.head) @@ -98,6 +101,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(shorts) testImplicit(shorts.toSeq) testImplicit(shorts.toSeq)(newShortSeqEncoder) + testImplicit(shorts.toImmutableArraySeq) val ints = Array(4, 6, 5) testImplicit(ints.head) @@ -105,6 +109,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(ints) testImplicit(ints.toSeq) testImplicit(ints.toSeq)(newIntSeqEncoder) + testImplicit(ints.toImmutableArraySeq) val longs = Array(System.nanoTime(), System.currentTimeMillis()) testImplicit(longs.head) @@ -112,6 +117,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(longs) testImplicit(longs.toSeq) testImplicit(longs.toSeq)(newLongSeqEncoder) + testImplicit(longs.toImmutableArraySeq) val floats = Array(3f, 10.9f) testImplicit(floats.head) @@ -119,6 +125,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(floats) testImplicit(floats.toSeq) testImplicit(floats.toSeq)(newFloatSeqEncoder) + testImplicit(floats.toImmutableArraySeq) val doubles = Array(23.78d, -329.6d) testImplicit(doubles.head) @@ -126,22 +133,26 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(doubles) testImplicit(doubles.toSeq) testImplicit(doubles.toSeq)(newDoubleSeqEncoder) + testImplicit(doubles.toImmutableArraySeq) val strings = Array("foo", "baz", "bar") testImplicit(strings.head) testImplicit(strings) testImplicit(strings.toSeq) testImplicit(strings.toSeq)(newStringSeqEncoder) + testImplicit(strings.toImmutableArraySeq) val myTypes = Array(MyType(12L, Math.E, Math.PI), MyType(0, 0, 0)) testImplicit(myTypes.head) testImplicit(myTypes) testImplicit(myTypes.toSeq) testImplicit(myTypes.toSeq)(newProductSeqEncoder[MyType]) + testImplicit(myTypes.toImmutableArraySeq) // Others. val decimal = java.math.BigDecimal.valueOf(3141527000000000000L, 18) testImplicit(decimal) + testImplicit(Array(decimal).toImmutableArraySeq) testImplicit(BigDecimal(decimal)) testImplicit(Date.valueOf(LocalDate.now())) testImplicit(LocalDate.now()) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 52461d1ebaeaa..ac9619487f02c 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -24,6 +24,7 @@ import java.time._ import java.util import java.util.{List => JList, Locale, Map => JMap} +import scala.collection.immutable import scala.collection.mutable import scala.reflect.ClassTag @@ -46,6 +47,7 @@ import org.apache.spark.sql.types.Decimal */ object ArrowDeserializers { import ArrowEncoderUtils._ + import org.apache.spark.util.ArrayImplicits._ /** * Create an Iterator of `T`. This iterator takes an Iterator of Arrow IPC Streams, and @@ -222,6 +224,13 @@ object ArrowDeserializers { ScalaCollectionUtils.wrap(array) } } + } else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) { + new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) { + def value(i: Int): immutable.ArraySeq[Any] = { + val array = getArray(vector, i, deserializer)(element.clsTag) + array.asInstanceOf[Array[_]].toImmutableArraySeq + } + } } else if (isSubClass(Classes.ITERABLE, tag)) { val companion = ScalaCollectionUtils.getIterableCompanion(tag) new VectorFieldDeserializer[Iterable[Any], ListVector](v) { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala index 6d1325b55d414..5b1539e39f4f4 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.connect.client.arrow +import scala.collection.immutable import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -26,6 +27,7 @@ import org.apache.arrow.vector.complex.StructVector private[arrow] object ArrowEncoderUtils { object Classes { val MUTABLE_ARRAY_SEQ: Class[_] = classOf[mutable.ArraySeq[_]] + val IMMUTABLE_ARRAY_SEQ: Class[_] = classOf[immutable.ArraySeq[_]] val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]] val MAP: Class[_] = classOf[scala.collection.Map[_, _]] val JLIST: Class[_] = classOf[java.util.List[_]]