diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index b6ab60a91955..cc52b6d8a14a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.time.{LocalDateTime, ZoneOffset} +import java.time.Year import java.util.Arrays import org.apache.spark.rdd.RDD @@ -33,22 +33,22 @@ private[sql] case class MyLabeledPoint(label: Double, features: TestUDT.MyDenseV def getFeatures: TestUDT.MyDenseVector = features } -private[sql] case class FooWithDate(date: LocalDateTime, s: String, i: Int) +private[sql] case class FooWithDate(year: Year, s: String, i: Int) -private[sql] class LocalDateTimeUDT extends UserDefinedType[LocalDateTime] { - override def sqlType: DataType = LongType +private[sql] class YearUDT extends UserDefinedType[Year] { + override def sqlType: DataType = IntegerType - override def serialize(obj: LocalDateTime): Long = { - obj.toEpochSecond(ZoneOffset.UTC) + override def serialize(obj: Year): Int = { + obj.getValue } - def deserialize(datum: Any): LocalDateTime = datum match { - case value: Long => LocalDateTime.ofEpochSecond(value, 0, ZoneOffset.UTC) + def deserialize(datum: Any): Year = datum match { + case value: Int => Year.of(value) } - override def userClass: Class[LocalDateTime] = classOf[LocalDateTime] + override def userClass: Class[Year] = classOf[Year] - private[spark] override def asNullable: LocalDateTimeUDT = this + private[spark] override def asNullable: YearUDT = this } class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with ParquetTest @@ -258,19 +258,17 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque test("SPARK-30993: UserDefinedType matched to fixed length SQL type shouldn't be corrupted") { def concatFoo(a: FooWithDate, b: FooWithDate): FooWithDate = { - FooWithDate(b.date, a.s + b.s, a.i) + FooWithDate(b.year, a.s + b.s, a.i) } - UDTRegistration.register(classOf[LocalDateTime].getName, classOf[LocalDateTimeUDT].getName) + UDTRegistration.register(classOf[Year].getName, classOf[YearUDT].getName) - // remove sub-millisecond part as we only use millis based timestamp while serde - val date = LocalDateTime.ofEpochSecond(LocalDateTime.now().toEpochSecond(ZoneOffset.UTC), - 0, ZoneOffset.UTC) - val inputDS = List(FooWithDate(date, "Foo", 1), FooWithDate(date, "Foo", 3), - FooWithDate(date, "Foo", 3)).toDS() + val year = Year.now() + val inputDS = List(FooWithDate(year, "Foo", 1), FooWithDate(year, "Foo", 3), + FooWithDate(year, "Foo", 3)).toDS() val agg = inputDS.groupByKey(x => x.i).mapGroups((_, iter) => iter.reduce(concatFoo)) val result = agg.collect() - assert(result.toSet === Set(FooWithDate(date, "FooFoo", 3), FooWithDate(date, "Foo", 1))) + assert(result.toSet === Set(FooWithDate(year, "FooFoo", 3), FooWithDate(year, "Foo", 1))) } }