ColumnVector.java
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.vectorized;
 
+import scala.PartialFunction;
+
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.Decimal;
@@ -336,10 +338,21 @@ public final VariantVal getVariant(int rowId) {
    * Sets up the data type of this column vector.
    */
   protected ColumnVector(DataType type) {
-    if (type instanceof UserDefinedType) {
-      this.type = ((UserDefinedType) type).sqlType();
-    } else {
-      this.type = type;
-    }
+    this.type = type.transformRecursively(
+      new PartialFunction<DataType, DataType>() {
+        @Override
+        public boolean isDefinedAt(DataType x) {
+          return x instanceof UserDefinedType<?>;
+        }
+
+        @Override
+        public DataType apply(DataType t) {
+          if (t instanceof UserDefinedType<?> udt) {
+            return udt.sqlType();
+          } else {
+            return t;
+          }
+        }
+      });
   }
 }
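In effect, the constructor now unwraps user-defined types at any nesting depth rather than only at the top level. A rough Scala sketch of the same rewrite, assuming a YearUDT whose sqlType is IntegerType (like the one the test suite below imports):

import org.apache.spark.sql.YearUDT  // test-scope UDT backed by IntegerType
import org.apache.spark.sql.types._

// A schema with a UDT buried below the top level.
val nested = StructType(Seq(StructField("years", ArrayType(new YearUDT))))

// The recursive rewrite applied by the new constructor: every UDT found
// anywhere in the type tree is replaced by its underlying SQL type.
val unwrapped = nested.transformRecursively {
  case udt: UserDefinedType[_] => udt.sqlType
}
// unwrapped == StructType(Seq(StructField("years", ArrayType(IntegerType))))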
@@ -196,8 +196,7 @@ class ParquetToSparkSchemaConverter(
       field: ColumnIO,
       sparkReadType: Option[DataType] = None): ParquetColumn = {
     val targetType = sparkReadType.map {
-      case udt: UserDefinedType[_] => udt.sqlType
-      case otherType => otherType
+      _.transformRecursively { case t: UserDefinedType[_] => t.sqlType }
     }
     field match {
       case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType)

Review thread on this change:

dongjoon-hyun (Member): Just a question. What about the ORC file format?

Member Author: Hi @dongjoon-hyun, good question. I have an umbrella ticket for UDT improvements. Let me check the other formats and readers in follow-ups if necessary.
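To see why the recursive form matters here, compare the two behaviors on a read type that nests a UDT; a hedged sketch, assuming the same IntegerType-backed YearUDT as above:

import org.apache.spark.sql.YearUDT
import org.apache.spark.sql.types._

val yearUDT = new YearUDT
val readType: DataType = ArrayType(yearUDT)

// Old code: only a top-level UDT was unwrapped, so a UDT nested in an
// array, map, or struct survived into Parquet column resolution.
val topLevelOnly = readType match {
  case udt: UserDefinedType[_] => udt.sqlType
  case otherType => otherType
}  // still an ArrayType containing the UDT

// New code: the UDT is unwrapped at any depth.
val recursive = readType.transformRecursively {
  case t: UserDefinedType[_] => t.sqlType
}  // ArrayType(IntegerType, containsNull = true)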
ColumnVectorSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.vectorized
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.YearUDT
 import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.execution.columnar.{ColumnAccessor, ColumnDictionary}
@@ -926,5 +927,27 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper {
         }
       }
     }
   }
+
+  val yearUDT = new YearUDT
+  testVectors("user defined type", 10, yearUDT) { testVector =>
+    assert(testVector.dataType() === IntegerType)
+    (0 until 10).foreach { i =>
+      testVector.appendInt(i)
+    }
+  }
+
+  testVectors("user defined type in map type",
+      10, MapType(IntegerType, yearUDT)) { testVector =>
+    assert(testVector.dataType() === MapType(IntegerType, IntegerType))
+  }
+
+  testVectors("user defined type in array type",
+      10, ArrayType(yearUDT, containsNull = true)) { testVector =>
+    assert(testVector.dataType() === ArrayType(IntegerType, containsNull = true))
+  }
+
+  testVectors("user defined type in struct type",
+      10, StructType(Seq(StructField("year", yearUDT)))) { testVector =>
+    assert(testVector.dataType() === StructType(Seq(StructField("year", IntegerType))))
+  }
 }
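YearUDT itself is defined elsewhere in Spark's test sources (imported above as org.apache.spark.sql.YearUDT). For orientation, a minimal sketch of what such a UDT can look like; the Year user class and method bodies here are illustrative, not the actual test class:

import org.apache.spark.sql.types._

// Hypothetical user-space value wrapped by the UDT.
case class Year(value: Int)

// A YearUDT-style UDT stores its value as a plain integer, so its
// sqlType is IntegerType, which is exactly what the assertions above
// expect once the column vector unwraps the UDT.
class YearUDT extends UserDefinedType[Year] {
  override def sqlType: DataType = IntegerType
  override def serialize(obj: Year): Any = obj.value
  override def deserialize(datum: Any): Year = datum match {
    case i: Int => Year(i)
  }
  override def userClass: Class[Year] = classOf[Year]
}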