From f026381feee879b75698696e952582902e332b0e Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Wed, 2 Jul 2025 19:04:31 +0800 Subject: [PATCH] [SPARK-52651][SQL] Handle User Defined Type in Nested Column Vectors --- .../spark/sql/vectorized/ColumnVector.java | 23 +++++++++++++---- .../parquet/ParquetSchemaConverter.scala | 3 +-- .../vectorized/ColumnVectorSuite.scala | 25 ++++++++++++++++++- 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index 54b62c00283f..f1d1f5b3ea80 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.vectorized; +import scala.PartialFunction; + import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; @@ -336,10 +338,21 @@ public final VariantVal getVariant(int rowId) { * Sets up the data type of this column vector. 
*/ protected ColumnVector(DataType type) { - if (type instanceof UserDefinedType) { - this.type = ((UserDefinedType) type).sqlType(); - } else { - this.type = type; - } + this.type = type.transformRecursively( + new PartialFunction<DataType, DataType>() { + @Override + public boolean isDefinedAt(DataType x) { + return x instanceof UserDefinedType; + } + + @Override + public DataType apply(DataType t) { + if (t instanceof UserDefinedType udt) { + return udt.sqlType(); + } else { + return t; + } + } + }); } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index e05d5fe2fd88..16bd776bea0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -196,8 +196,7 @@ class ParquetToSparkSchemaConverter( field: ColumnIO, sparkReadType: Option[DataType] = None): ParquetColumn = { val targetType = sparkReadType.map { - case udt: UserDefinedType[_] => udt.sqlType - case otherType => otherType + _.transformRecursively { case t: UserDefinedType[_] => t.sqlType } } field match { case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 0edbfd10d8cd..a0fe44b96e7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.vectorized import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.YearUDT import 
org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.execution.columnar.{ColumnAccessor, ColumnDictionary} @@ -926,5 +927,27 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { } } } -} + val yearUDT = new YearUDT + testVectors("user defined type", 10, yearUDT) { testVector => + assert(testVector.dataType() === IntegerType) + (0 until 10).foreach { i => + testVector.appendInt(i) + } + } + + testVectors("user defined type in map type", + 10, MapType(IntegerType, yearUDT)) { testVector => + assert(testVector.dataType() === MapType(IntegerType, IntegerType)) + } + + testVectors("user defined type in array type", + 10, ArrayType(yearUDT, containsNull = true)) { testVector => + assert(testVector.dataType() === ArrayType(IntegerType, containsNull = true)) + } + + testVectors("user defined type in struct type", + 10, StructType(Seq(StructField("year", yearUDT)))) { testVector => + assert(testVector.dataType() === StructType(Seq(StructField("year", IntegerType)))) + } +}