From f5892066ed000c0c806e7502e513df89ef508fc9 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 8 Jan 2024 14:19:11 +0800 Subject: [PATCH 1/3] [SPARK-46615][CONNECT] Support s.c.immutable.ArraySeq in ArrowDeserializers --- .../org/apache/spark/sql/SQLImplicitsTestSuite.scala | 12 ++++++++++++ .../sql/connect/client/arrow/ArrowDeserializer.scala | 8 ++++++++ .../sql/connect/client/arrow/ArrowEncoderUtils.scala | 2 ++ .../connect/client/arrow/ScalaCollectionUtils.scala | 6 +++++- .../test/scala/org/apache/spark/sql/UDFSuite.scala | 2 +- 5 files changed, 28 insertions(+), 2 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index b2c13850a13a0..3afc56051d923 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -21,6 +21,8 @@ import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} import java.time.temporal.ChronoUnit import java.util.concurrent.atomic.AtomicLong +import scala.collection.immutable + import io.grpc.inprocess.InProcessChannelBuilder import org.apache.arrow.memory.RootAllocator import org.apache.commons.lang3.SystemUtils @@ -84,6 +86,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(booleans) testImplicit(booleans.toSeq) testImplicit(booleans.toSeq)(newBooleanSeqEncoder) + testImplicit(immutable.ArraySeq.unsafeWrapArray(booleans)) val bytes = Array(76.toByte, 59.toByte, 121.toByte) testImplicit(bytes.head) @@ -91,6 +94,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(bytes) testImplicit(bytes.toSeq) testImplicit(bytes.toSeq)(newByteSeqEncoder) + testImplicit(immutable.ArraySeq.unsafeWrapArray(bytes)) val shorts = Array(21.toShort, (-213).toShort, 14876.toShort) testImplicit(shorts.head) @@ -98,6 +102,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(shorts) testImplicit(shorts.toSeq) testImplicit(shorts.toSeq)(newShortSeqEncoder) + testImplicit(immutable.ArraySeq.unsafeWrapArray(shorts)) val ints = Array(4, 6, 5) testImplicit(ints.head) @@ -105,6 +110,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(ints) testImplicit(ints.toSeq) testImplicit(ints.toSeq)(newIntSeqEncoder) + testImplicit(immutable.ArraySeq.unsafeWrapArray(ints)) val longs = Array(System.nanoTime(), System.currentTimeMillis()) testImplicit(longs.head) @@ -112,6 +118,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(longs) testImplicit(longs.toSeq) testImplicit(longs.toSeq)(newLongSeqEncoder) + testImplicit(immutable.ArraySeq.unsafeWrapArray(longs)) val floats = Array(3f, 10.9f) testImplicit(floats.head) @@ -119,6 +126,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(floats) testImplicit(floats.toSeq) testImplicit(floats.toSeq)(newFloatSeqEncoder) + testImplicit(immutable.ArraySeq.unsafeWrapArray(floats)) val doubles = Array(23.78d, -329.6d) testImplicit(doubles.head) @@ -126,22 +134,26 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(doubles) testImplicit(doubles.toSeq) testImplicit(doubles.toSeq)(newDoubleSeqEncoder) + testImplicit(immutable.ArraySeq.unsafeWrapArray(doubles)) val strings = Array("foo", "baz", "bar") testImplicit(strings.head) testImplicit(strings) testImplicit(strings.toSeq) testImplicit(strings.toSeq)(newStringSeqEncoder) + testImplicit(immutable.ArraySeq.unsafeWrapArray(strings)) val myTypes = Array(MyType(12L, Math.E, Math.PI), MyType(0, 0, 0)) testImplicit(myTypes.head) testImplicit(myTypes) testImplicit(myTypes.toSeq) testImplicit(myTypes.toSeq)(newProductSeqEncoder[MyType]) + testImplicit(immutable.ArraySeq.unsafeWrapArray(myTypes)) // Others. val decimal = java.math.BigDecimal.valueOf(3141527000000000000L, 18) testImplicit(decimal) + testImplicit(immutable.ArraySeq.unsafeWrapArray(Array(decimal))) testImplicit(BigDecimal(decimal)) testImplicit(Date.valueOf(LocalDate.now())) testImplicit(LocalDate.now()) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 52461d1ebaeaa..513fc30a584ad 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -24,6 +24,7 @@ import java.time._ import java.util import java.util.{List => JList, Locale, Map => JMap} +import scala.collection.immutable import scala.collection.mutable import scala.reflect.ClassTag @@ -222,6 +223,13 @@ object ArrowDeserializers { ScalaCollectionUtils.wrap(array) } } + } else if (isSubClass(Classes.IMMUTABLE_ARRAY_SEQ, tag)) { + new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) { + def value(i: Int): immutable.ArraySeq[Any] = { + val array = getArray(vector, i, deserializer)(element.clsTag) + ScalaCollectionUtils.toImmutableArraySeq(array) + } + } } else if (isSubClass(Classes.ITERABLE, tag)) { val companion = ScalaCollectionUtils.getIterableCompanion(tag) new VectorFieldDeserializer[Iterable[Any], ListVector](v) { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala index 6d1325b55d414..5b1539e39f4f4 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.connect.client.arrow +import scala.collection.immutable import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag @@ -26,6 +27,7 @@ import org.apache.arrow.vector.complex.StructVector private[arrow] object ArrowEncoderUtils { object Classes { val MUTABLE_ARRAY_SEQ: Class[_] = classOf[mutable.ArraySeq[_]] + val IMMUTABLE_ARRAY_SEQ: Class[_] = classOf[immutable.ArraySeq[_]] val ITERABLE: Class[_] = classOf[scala.collection.Iterable[_]] val MAP: Class[_] = classOf[scala.collection.Map[_, _]] val JLIST: Class[_] = classOf[java.util.List[_]] diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala index 8bc4c0435d0d3..16157ded5e2c1 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.connect.client.arrow -import scala.collection.{mutable, IterableFactory, MapFactory} +import scala.collection.{immutable, mutable, IterableFactory, MapFactory} import scala.reflect.ClassTag import org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion @@ -34,4 +34,8 @@ private[arrow] object ScalaCollectionUtils { def wrap[T](array: AnyRef): mutable.ArraySeq[T] = { mutable.ArraySeq.make(array.asInstanceOf[Array[T]]) } + + def toImmutableArraySeq[T](array: AnyRef): immutable.ArraySeq[T] = { + immutable.ArraySeq.unsafeWrapArray(array.asInstanceOf[Array[T]]) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 2bd649ea85e52..7952e0f174b74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -840,7 +840,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { Row(ArrayBuffer(100))) val myUdf2 = udf((a: immutable.ArraySeq[Int]) => - immutable.ArraySeq.unsafeWrapArray[Int](a.appended(5).appended(6).toArray)) + immutable.ArraySeq.unsafeWrapArray[Int]((a :+ 5 :+ 6).toArray)) checkAnswer(Seq(Array(1, 2, 3)) .toDF("col") .select(myUdf2(Column("col"))), From 430572b156253661bfbba153649f10e248e556dc Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 8 Jan 2024 17:28:54 +0800 Subject: [PATCH 2/3] [SPARK-46615][CONNECT] Support s.c.immutable.ArraySeq in ArrowDeserializers --- .../spark/sql/SQLImplicitsTestSuite.scala | 23 +++++++++---------- .../client/arrow/ArrowDeserializer.scala | 3 ++- .../client/arrow/ScalaCollectionUtils.scala | 6 +---- 3 files changed, 14 insertions(+), 18 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index 3afc56051d923..3e4704b6ab8e0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -21,8 +21,6 @@ import java.time.{Duration, Instant, LocalDate, LocalDateTime, Period} import java.time.temporal.ChronoUnit import java.util.concurrent.atomic.AtomicLong -import scala.collection.immutable - import io.grpc.inprocess.InProcessChannelBuilder import org.apache.arrow.memory.RootAllocator import org.apache.commons.lang3.SystemUtils @@ -54,6 +52,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { test("test implicit encoder resolution") { val spark = session + import org.apache.spark.util.ArrayImplicits._ import spark.implicits._ def testImplicit[T: Encoder](expected: T): Unit = { val encoder = encoderFor[T] @@ -86,7 +85,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(booleans) testImplicit(booleans.toSeq) testImplicit(booleans.toSeq)(newBooleanSeqEncoder) - testImplicit(immutable.ArraySeq.unsafeWrapArray(booleans)) + testImplicit(booleans.toImmutableArraySeq) val bytes = Array(76.toByte, 59.toByte, 121.toByte) testImplicit(bytes.head) @@ -94,7 +93,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(bytes) testImplicit(bytes.toSeq) testImplicit(bytes.toSeq)(newByteSeqEncoder) - testImplicit(immutable.ArraySeq.unsafeWrapArray(bytes)) + testImplicit(bytes.toImmutableArraySeq) val shorts = Array(21.toShort, (-213).toShort, 14876.toShort) testImplicit(shorts.head) @@ -102,7 +101,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(shorts) testImplicit(shorts.toSeq) testImplicit(shorts.toSeq)(newShortSeqEncoder) - testImplicit(immutable.ArraySeq.unsafeWrapArray(shorts)) + testImplicit(shorts.toImmutableArraySeq) val ints = Array(4, 6, 5) testImplicit(ints.head) @@ -110,7 +109,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(ints) testImplicit(ints.toSeq) testImplicit(ints.toSeq)(newIntSeqEncoder) - testImplicit(immutable.ArraySeq.unsafeWrapArray(ints)) + testImplicit(ints.toImmutableArraySeq) val longs = Array(System.nanoTime(), System.currentTimeMillis()) testImplicit(longs.head) @@ -118,7 +117,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(longs) testImplicit(longs.toSeq) testImplicit(longs.toSeq)(newLongSeqEncoder) - testImplicit(immutable.ArraySeq.unsafeWrapArray(longs)) + testImplicit(longs.toImmutableArraySeq) val floats = Array(3f, 10.9f) testImplicit(floats.head) @@ -126,7 +125,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(floats) testImplicit(floats.toSeq) testImplicit(floats.toSeq)(newFloatSeqEncoder) - testImplicit(immutable.ArraySeq.unsafeWrapArray(floats)) + testImplicit(floats.toImmutableArraySeq) val doubles = Array(23.78d, -329.6d) testImplicit(doubles.head) @@ -134,26 +133,26 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { testImplicit(doubles) testImplicit(doubles.toSeq) testImplicit(doubles.toSeq)(newDoubleSeqEncoder) - testImplicit(immutable.ArraySeq.unsafeWrapArray(doubles)) + testImplicit(doubles.toImmutableArraySeq) val strings = Array("foo", "baz", "bar") testImplicit(strings.head) testImplicit(strings) testImplicit(strings.toSeq) testImplicit(strings.toSeq)(newStringSeqEncoder) - testImplicit(immutable.ArraySeq.unsafeWrapArray(strings)) + testImplicit(strings.toImmutableArraySeq) val myTypes = Array(MyType(12L, Math.E, Math.PI), MyType(0, 0, 0)) testImplicit(myTypes.head) testImplicit(myTypes) testImplicit(myTypes.toSeq) testImplicit(myTypes.toSeq)(newProductSeqEncoder[MyType]) - testImplicit(immutable.ArraySeq.unsafeWrapArray(myTypes)) + testImplicit(myTypes.toImmutableArraySeq) // Others. val decimal = java.math.BigDecimal.valueOf(3141527000000000000L, 18) testImplicit(decimal) - testImplicit(immutable.ArraySeq.unsafeWrapArray(Array(decimal))) + testImplicit(Array(decimal).toImmutableArraySeq) testImplicit(BigDecimal(decimal)) testImplicit(Date.valueOf(LocalDate.now())) testImplicit(LocalDate.now()) diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index 513fc30a584ad..ac9619487f02c 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.types.Decimal */ object ArrowDeserializers { import ArrowEncoderUtils._ + import org.apache.spark.util.ArrayImplicits._ /** * Create an Iterator of `T`. This iterator takes an Iterator of Arrow IPC Streams, and @@ -227,7 +228,7 @@ object ArrowDeserializers { new VectorFieldDeserializer[immutable.ArraySeq[Any], ListVector](v) { def value(i: Int): immutable.ArraySeq[Any] = { val array = getArray(vector, i, deserializer)(element.clsTag) - ScalaCollectionUtils.toImmutableArraySeq(array) + array.asInstanceOf[Array[_]].toImmutableArraySeq } } } else if (isSubClass(Classes.ITERABLE, tag)) { diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala index 16157ded5e2c1..8bc4c0435d0d3 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ScalaCollectionUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql.connect.client.arrow -import scala.collection.{immutable, mutable, IterableFactory, MapFactory} +import scala.collection.{mutable, IterableFactory, MapFactory} import scala.reflect.ClassTag import org.apache.spark.sql.connect.client.arrow.ArrowDeserializers.resolveCompanion @@ -34,8 +34,4 @@ private[arrow] object ScalaCollectionUtils { def wrap[T](array: AnyRef): mutable.ArraySeq[T] = { mutable.ArraySeq.make(array.asInstanceOf[Array[T]]) } - - def toImmutableArraySeq[T](array: AnyRef): immutable.ArraySeq[T] = { - immutable.ArraySeq.unsafeWrapArray(array.asInstanceOf[Array[T]]) - } } From 718ef817b575d5c5627df0dbf9904aabadabc734 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 12 Jan 2024 13:47:24 +0800 Subject: [PATCH 3/3] [SPARK-46615][CONNECT] Support s.c.immutable.ArraySeq in ArrowDeserializers --- sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7952e0f174b74..2bd649ea85e52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -840,7 +840,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { Row(ArrayBuffer(100))) val myUdf2 = udf((a: immutable.ArraySeq[Int]) => - immutable.ArraySeq.unsafeWrapArray[Int]((a :+ 5 :+ 6).toArray)) + immutable.ArraySeq.unsafeWrapArray[Int](a.appended(5).appended(6).toArray)) checkAnswer(Seq(Array(1, 2, 3)) .toDF("col") .select(myUdf2(Column("col"))),