From ab4c4555b9f21df5356ace1b5f0e8b6455cae2de Mon Sep 17 00:00:00 2001
From: Jolan Rensen
Date: Sun, 7 Apr 2024 15:51:08 +0200
Subject: [PATCH] added encoding for DatePeriod, DateTimePeriod, Instant,
 LocalDateTime, and LocalDate, Duration not working
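
Adds UserDefinedTypes (UDTs) for the kotlinx.datetime types so they can be
used in Datasets directly: KotlinTypeInference.encoderFor() now registers
each UDT with Spark's UDTRegistration before inferring an encoder.
Illustrative usage, mirroring the new tests:

    val dates = listOf(LocalDate.now().toKotlinLocalDate(), LocalDate.now().toKotlinLocalDate())
    dates.toDS().collectAsList() shouldBe dates

kotlin.time.Duration does not round-trip yet, likely because it is a value
class; its UDT registration and test are disabled for now.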
---
 .../jetbrains/kotlinx/spark/api/Encoding.kt   | 32 ++++++---
 .../kotlinx/spark/api/udts/DatePeriodUdt.kt   | 27 +++++++
 .../spark/api/udts/DateTimePeriodUdt.kt       | 46 ++++++++++++
 .../kotlinx/spark/api/udts/DurationUdt.kt     | 46 ++++++++++++
 .../kotlinx/spark/api/udts/InstantUdt.kt      | 26 +++++++
 .../spark/api/udts/LocalDateTimeUdt.kt        | 26 +++++++
 .../kotlinx/spark/api/udts/LocalDateUdt.kt    | 26 +++++++
 .../kotlinx/spark/api/EncodingTest.kt         | 70 +++++++++++++++++++
 8 files changed, 288 insertions(+), 11 deletions(-)
 create mode 100644 kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DatePeriodUdt.kt
 create mode 100644 kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DateTimePeriodUdt.kt
 create mode 100644 kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DurationUdt.kt
 create mode 100644 kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/InstantUdt.kt
 create mode 100644 kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/LocalDateTimeUdt.kt
 create mode 100644 kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/LocalDateUdt.kt

diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt
index ff48b59f..874b5e66 100644
--- a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt
@@ -49,6 +49,11 @@ import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.jetbrains.kotlinx.spark.api.plugin.annotations.ColumnName
 import org.jetbrains.kotlinx.spark.api.plugin.annotations.Sparkify
+import org.jetbrains.kotlinx.spark.api.udts.DatePeriodUdt
+import org.jetbrains.kotlinx.spark.api.udts.DateTimePeriodUdt
+import org.jetbrains.kotlinx.spark.api.udts.InstantUdt
+import org.jetbrains.kotlinx.spark.api.udts.LocalDateTimeUdt
+import org.jetbrains.kotlinx.spark.api.udts.LocalDateUdt
 import scala.reflect.ClassTag
 import java.io.Serializable
 import java.util.*
@@ -170,12 +175,14 @@ object KotlinTypeInference : Serializable {
      * @return an [AgnosticEncoder] for the given [kType].
      */
     @Suppress("UNCHECKED_CAST")
-    fun <T> encoderFor(kType: KType): AgnosticEncoder<T> =
-        encoderFor(
+    fun <T> encoderFor(kType: KType): AgnosticEncoder<T> {
+        registerUdts()
+        return encoderFor(
             currentType = kType,
             seenTypeSet = emptySet(),
             typeVariables = emptyMap(),
         ) as AgnosticEncoder<T>
+    }
 
     private inline fun <reified T> KType.isSubtypeOf(): Boolean = isSubtypeOf(typeOf<T>())
 
@@ -296,6 +303,16 @@ object KotlinTypeInference : Serializable {
     private fun <K, V> transitiveMerge(a: Map<K, V>, b: Map<K, V>, valueToKey: (V) -> K?): Map<K, V> =
         a + b.mapValues { a.getOrDefault(valueToKey(it.value), it.value) }
 
+    private fun registerUdts() {
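+        // Map each kotlinx.datetime class to its UDT by fully-qualified name so
+        // Spark can look the codec up per class. This runs on every encoderFor()
+        // call; duplicate registrations are simply ignored by UDTRegistration.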
+        UDTRegistration.register(kotlinx.datetime.LocalDate::class.java.name, LocalDateUdt::class.java.name)
+        UDTRegistration.register(kotlinx.datetime.Instant::class.java.name, InstantUdt::class.java.name)
+        UDTRegistration.register(kotlinx.datetime.LocalDateTime::class.java.name, LocalDateTimeUdt::class.java.name)
+        UDTRegistration.register(kotlinx.datetime.DatePeriod::class.java.name, DatePeriodUdt::class.java.name)
+        UDTRegistration.register(kotlinx.datetime.DateTimePeriod::class.java.name, DateTimePeriodUdt::class.java.name)
+        // TODO Fails, likely because kotlin.time.Duration is a value class.
+        // UDTRegistration.register(kotlin.time.Duration::class.java.name, DurationUdt::class.java.name)
+    }
+
     /**
      *
      */
@@ -375,19 +392,12 @@ object KotlinTypeInference : Serializable {
         currentType.isSubtypeOf<java.math.BigInteger?>() -> AgnosticEncoders.`JavaBigIntEncoder$`.`MODULE$`
         currentType.isSubtypeOf<CalendarInterval?>() -> AgnosticEncoders.`CalendarIntervalEncoder$`.`MODULE$`
         currentType.isSubtypeOf<java.time.LocalDate?>() -> AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER()
-        currentType.isSubtypeOf<kotlinx.datetime.LocalDate?>() -> TODO("User java.time.LocalDate for now. We'll create a UDT for this.")
         currentType.isSubtypeOf<java.sql.Date?>() -> AgnosticEncoders.STRICT_DATE_ENCODER()
         currentType.isSubtypeOf<java.time.Instant?>() -> AgnosticEncoders.STRICT_INSTANT_ENCODER()
-        currentType.isSubtypeOf<kotlinx.datetime.Instant?>() -> TODO("Use java.time.Instant for now. We'll create a UDT for this.")
-        currentType.isSubtypeOf<kotlin.time.TimeMark?>() -> TODO("Use java.time.Instant for now. We'll create a UDT for this.")
         currentType.isSubtypeOf<java.sql.Timestamp?>() -> AgnosticEncoders.STRICT_TIMESTAMP_ENCODER()
         currentType.isSubtypeOf<java.time.LocalDateTime?>() -> AgnosticEncoders.`LocalDateTimeEncoder$`.`MODULE$`
-        currentType.isSubtypeOf<kotlinx.datetime.LocalDateTime?>() -> TODO("Use java.time.LocalDateTime for now. We'll create a UDT for this.")
         currentType.isSubtypeOf<java.time.Duration?>() -> AgnosticEncoders.`DayTimeIntervalEncoder$`.`MODULE$`
-        currentType.isSubtypeOf<kotlin.time.Duration?>() -> TODO("Use java.time.Duration for now. We'll create a UDT for this.")
         currentType.isSubtypeOf<java.time.Period?>() -> AgnosticEncoders.`YearMonthIntervalEncoder$`.`MODULE$`
-        currentType.isSubtypeOf<kotlinx.datetime.DatePeriod?>() -> TODO("Use java.time.Period for now. We'll create a UDT for this.")
-        currentType.isSubtypeOf<kotlinx.datetime.DateTimePeriod?>() -> TODO("Use java.time.Period for now. We'll create a UDT for this.")
         currentType.isSubtypeOf<org.apache.spark.sql.Row?>() -> AgnosticEncoders.`UnboundRowEncoder$`.`MODULE$`
 
         // enums
@@ -414,6 +424,8 @@ object KotlinTypeInference : Serializable {
             AgnosticEncoders.UDTEncoder(udt, udt.javaClass)
         }
 
+        currentType.isSubtypeOf<kotlin.time.Duration?>() -> TODO("kotlin.time.Duration is unsupported. Use java.time.Duration for now.")
+
         currentType.isSubtypeOf<scala.Option<*>?>() -> {
             val elementEncoder = encoderFor(
                 currentType = tArguments.first().type!!,
@@ -666,8 +678,6 @@ object KotlinTypeInference : Serializable {
                     fields.asScalaSeq(),
                 )
             }
-
-//            else -> throw IllegalArgumentException("No encoder found for type $currentType")
         }
     }
 

diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DatePeriodUdt.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DatePeriodUdt.kt
new file mode 100644
index 00000000..3705cb5a
--- /dev/null
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DatePeriodUdt.kt
@@ -0,0 +1,27 @@
+package org.jetbrains.kotlinx.spark.api.udts
+
+import kotlinx.datetime.DatePeriod
+import kotlinx.datetime.toJavaPeriod
+import kotlinx.datetime.toKotlinDatePeriod
+import org.apache.spark.sql.catalyst.util.IntervalUtils
+import org.apache.spark.sql.types.UserDefinedType
+import org.apache.spark.sql.types.YearMonthIntervalType
+
+/**
+ * NOTE: Just like Spark's encoding of java.time.Period, this is truncated to months:
+ * the backing YearMonthIntervalType only stores a month count, so days are dropped.
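+ *
+ * Illustrative example: DatePeriod(years = 1, months = 2, days = 3) is stored
+ * as 14 months and reads back as DatePeriod(years = 1, months = 2).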
+ */
+class DatePeriodUdt : UserDefinedType<DatePeriod>() {
+
+    override fun userClass(): Class<DatePeriod> = DatePeriod::class.java
+
+    override fun deserialize(datum: Any?): DatePeriod? =
+        when (datum) {
+            null -> null
+            is Int -> IntervalUtils.monthsToPeriod(datum).toKotlinDatePeriod()
+            else -> throw IllegalArgumentException("Unsupported datum: $datum")
+        }
+
+    override fun serialize(obj: DatePeriod?): Int? =
+        obj?.let { IntervalUtils.periodToMonths(it.toJavaPeriod()) }
+
+    override fun sqlType(): YearMonthIntervalType = YearMonthIntervalType.apply()
+}
\ No newline at end of file

diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DateTimePeriodUdt.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DateTimePeriodUdt.kt
new file mode 100644
index 00000000..3b939cf9
--- /dev/null
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DateTimePeriodUdt.kt
@@ -0,0 +1,46 @@
+package org.jetbrains.kotlinx.spark.api.udts
+
+import kotlinx.datetime.DateTimePeriod
+import org.apache.spark.sql.types.CalendarIntervalType
+import org.apache.spark.sql.types.`CalendarIntervalType$`
+import org.apache.spark.sql.types.UserDefinedType
+import org.apache.spark.unsafe.types.CalendarInterval
+import kotlin.time.Duration.Companion.hours
+import kotlin.time.Duration.Companion.minutes
+import kotlin.time.Duration.Companion.nanoseconds
+import kotlin.time.Duration.Companion.seconds
+
+/**
+ * NOTE: Backed by Spark's CalendarInterval (months, days, microseconds), so
+ * years are folded into months and the time-of-day part is truncated to
+ * whole microseconds.
+ */
+class DateTimePeriodUdt : UserDefinedType<DateTimePeriod>() {
+
+    override fun userClass(): Class<DateTimePeriod> = DateTimePeriod::class.java
+
+    override fun deserialize(datum: Any?): DateTimePeriod? =
+        when (datum) {
+            null -> null
+            is CalendarInterval ->
+                DateTimePeriod(
+                    months = datum.months,
+                    days = datum.days,
+                    nanoseconds = datum.microseconds * 1_000,
+                )
+
+            else -> throw IllegalArgumentException("Unsupported datum: $datum")
+        }
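+
+    // serialize folds years into months (1 year = 12 months) and collapses the
+    // signed time components into one whole-microsecond count via kotlin.time.Duration.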
+    override fun serialize(obj: DateTimePeriod?): CalendarInterval? =
+        obj?.let {
+            CalendarInterval(
+                /* months = */ obj.months + obj.years * 12,
+                /* days = */ obj.days,
+                /* microseconds = */
+                (obj.hours.hours +
+                        obj.minutes.minutes +
+                        obj.seconds.seconds +
+                        obj.nanoseconds.nanoseconds).inWholeMicroseconds,
+            )
+        }
+
+    override fun sqlType(): CalendarIntervalType = `CalendarIntervalType$`.`MODULE$`
+}
\ No newline at end of file

diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DurationUdt.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DurationUdt.kt
new file mode 100644
index 00000000..ff1e5df4
--- /dev/null
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/DurationUdt.kt
@@ -0,0 +1,46 @@
+package org.jetbrains.kotlinx.spark.api.udts
+
+import org.apache.spark.sql.catalyst.util.IntervalUtils
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.DayTimeIntervalType
+import org.apache.spark.sql.types.UserDefinedType
+import kotlin.time.Duration
+import kotlin.time.Duration.Companion.milliseconds
+import kotlin.time.Duration.Companion.nanoseconds
+import kotlin.time.toJavaDuration
+import kotlin.time.toKotlinDuration
+
+// TODO Fails, likely because Duration is a value class.
+class DurationUdt : UserDefinedType<Duration>() {
+
+    override fun userClass(): Class<Duration> = Duration::class.java
+
+    override fun deserialize(datum: Any?): Duration? =
+        when (datum) {
+            null -> null
+            is Long -> IntervalUtils.microsToDuration(datum).toKotlinDuration()
+//            is Long -> IntervalUtils.microsToDuration(datum).toKotlinDuration().let {
+//                // store in nanos
+//                it.inWholeNanoseconds shl 1
+//            }
+            else -> throw IllegalArgumentException("Unsupported datum: $datum")
+        }
+
+//    override fun serialize(obj: Duration): Long =
+//        IntervalUtils.durationToMicros(obj.toJavaDuration())
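+
+    // Experimental: interprets the raw underlying Long of the Duration value class,
+    // where the lowest bit discriminates the unit (0 = nanoseconds, 1 = milliseconds)
+    // and the remaining bits hold the value. Kept while investigating the failure above.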
+    fun serialize(rawValue: Long): Long {
+        val unitDiscriminator = rawValue.toInt() and 1
+        fun isInNanos() = unitDiscriminator == 0
+        val value = rawValue shr 1
+        val duration = if (isInNanos()) value.nanoseconds else value.milliseconds
+
+        return IntervalUtils.durationToMicros(duration.toJavaDuration())
+    }
+
+    override fun serialize(obj: Duration): Long =
+        IntervalUtils.durationToMicros(obj.toJavaDuration())
+
+    override fun sqlType(): DataType = DayTimeIntervalType.apply()
+}
\ No newline at end of file

diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/InstantUdt.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/InstantUdt.kt
new file mode 100644
index 00000000..7b8ba110
--- /dev/null
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/InstantUdt.kt
@@ -0,0 +1,26 @@
+package org.jetbrains.kotlinx.spark.api.udts
+
+import kotlinx.datetime.Instant
+import kotlinx.datetime.toJavaInstant
+import kotlinx.datetime.toKotlinInstant
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.`TimestampType$`
+import org.apache.spark.sql.types.UserDefinedType
+
+class InstantUdt : UserDefinedType<Instant>() {
+
+    override fun userClass(): Class<Instant> = Instant::class.java
+
+    override fun deserialize(datum: Any?): Instant? =
+        when (datum) {
+            null -> null
+            is Long -> DateTimeUtils.microsToInstant(datum).toKotlinInstant()
+            else -> throw IllegalArgumentException("Unsupported datum: $datum")
+        }
+
+    override fun serialize(obj: Instant?): Long? =
+        obj?.let { DateTimeUtils.instantToMicros(it.toJavaInstant()) }
+
+    override fun sqlType(): DataType = `TimestampType$`.`MODULE$`
+}
\ No newline at end of file

diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/LocalDateTimeUdt.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/LocalDateTimeUdt.kt
new file mode 100644
index 00000000..7dd4fa0d
--- /dev/null
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/LocalDateTimeUdt.kt
@@ -0,0 +1,26 @@
+package org.jetbrains.kotlinx.spark.api.udts
+
+import kotlinx.datetime.LocalDateTime
+import kotlinx.datetime.toJavaLocalDateTime
+import kotlinx.datetime.toKotlinLocalDateTime
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.`TimestampNTZType$`
+import org.apache.spark.sql.types.UserDefinedType
+
+class LocalDateTimeUdt : UserDefinedType<LocalDateTime>() {
+
+    override fun userClass(): Class<LocalDateTime> = LocalDateTime::class.java
+
+    override fun deserialize(datum: Any?): LocalDateTime? =
+        when (datum) {
+            null -> null
+            is Long -> DateTimeUtils.microsToLocalDateTime(datum).toKotlinLocalDateTime()
+            else -> throw IllegalArgumentException("Unsupported datum: $datum")
+        }
+
+    override fun serialize(obj: LocalDateTime?): Long? =
+        obj?.let { DateTimeUtils.localDateTimeToMicros(it.toJavaLocalDateTime()) }
+
+    override fun sqlType(): DataType = `TimestampNTZType$`.`MODULE$`
+}
\ No newline at end of file

diff --git a/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/LocalDateUdt.kt b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/LocalDateUdt.kt
new file mode 100644
index 00000000..033b05e5
--- /dev/null
+++ b/kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/udts/LocalDateUdt.kt
@@ -0,0 +1,26 @@
+package org.jetbrains.kotlinx.spark.api.udts
+
+import kotlinx.datetime.LocalDate
+import kotlinx.datetime.toJavaLocalDate
+import kotlinx.datetime.toKotlinLocalDate
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.`DateType$`
+import org.apache.spark.sql.types.UserDefinedType
+
+class LocalDateUdt : UserDefinedType<LocalDate>() {
+
+    override fun userClass(): Class<LocalDate> = LocalDate::class.java
+
+    override fun deserialize(datum: Any?): LocalDate? =
+        when (datum) {
+            null -> null
+            is Int -> DateTimeUtils.daysToLocalDate(datum).toKotlinLocalDate()
+            else -> throw IllegalArgumentException("Unsupported datum: $datum")
+        }
+
+    override fun serialize(obj: LocalDate?): Int? =
+        obj?.let { DateTimeUtils.localDateToDays(it.toJavaLocalDate()) }
+
+    override fun sqlType(): DataType = `DateType$`.`MODULE$`
+}
\ No newline at end of file

diff --git a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt
index 295faa19..05acc6d0 100644
--- a/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt
+++ b/kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/EncodingTest.kt
@@ -25,6 +25,11 @@ import io.kotest.core.spec.style.ShouldSpec
 import io.kotest.matchers.collections.shouldContainExactly
 import io.kotest.matchers.shouldBe
 import io.kotest.matchers.string.shouldContain
+import kotlinx.datetime.DateTimePeriod
+import kotlinx.datetime.toKotlinDatePeriod
+import kotlinx.datetime.toKotlinInstant
+import kotlinx.datetime.toKotlinLocalDate
+import kotlinx.datetime.toKotlinLocalDateTime
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.CalendarInterval
@@ -37,7 +42,12 @@ import java.sql.Timestamp
 import java.time.Duration
 import java.time.Instant
 import java.time.LocalDate
+import java.time.LocalDateTime
 import java.time.Period
+import kotlin.time.TimeMark
+import kotlin.time.TimeSource
+import kotlin.time.TimeSource.Monotonic
+import kotlin.time.toKotlinDuration
 
 class EncodingTest : ShouldSpec({
@@ -53,6 +63,12 @@ class EncodingTest : ShouldSpec({
             dataset.collectAsList() shouldBe dates
         }
 
+        should("handle Kotlinx LocalDate Datasets") {
+            val dates = listOf(LocalDate.now().toKotlinLocalDate(), LocalDate.now().toKotlinLocalDate())
+            val dataset = dates.toDS()
+            dataset.collectAsList() shouldBe dates
+        }
+
         should("handle Instant Datasets") {
             val instants = listOf(Instant.now(), Instant.now())
             val dataset: Dataset<Instant> = instants.toDS()
@@ -63,17 +79,44 @@ class EncodingTest : ShouldSpec({
             dataset.collectAsList().let { (first, second) ->
                 val (a, b) = instants
                 a.compareTo(first) shouldBe 0
                 b.compareTo(second) shouldBe 0
             }
         }
 
+        should("handle Kotlinx Instant Datasets") {
+            val instants = listOf(Instant.now().toKotlinInstant(), Instant.now().toKotlinInstant())
+            val dataset = instants.toDS()
+            dataset.collectAsList().let { (first, second) ->
+                val (a, b) = instants
+                a.compareTo(first) shouldBe 0
+                b.compareTo(second) shouldBe 0
+            }
+        }
+
         should("handle Timestamp Datasets") {
             val timeStamps = listOf(Timestamp(0L), Timestamp(1L))
             val dataset = timeStamps.toDS()
             dataset.collectAsList() shouldBe timeStamps
         }
 
+        should("handle LocalDateTime") {
+            val timeStamps = listOf(LocalDateTime.now(), LocalDateTime.now().plusDays(3))
+            val dataset = timeStamps.toDS()
+            dataset.collectAsList() shouldBe timeStamps
+        }
+
+        should("handle Kotlinx LocalDateTime") {
+            val timeStamps = listOf(LocalDateTime.now().toKotlinLocalDateTime(), LocalDateTime.now().plusDays(3).toKotlinLocalDateTime())
+            val dataset = timeStamps.toDS()
+            dataset.collectAsList() shouldBe timeStamps
+        }
+
         //#if sparkMinor >= 3.2
         should("handle Duration Datasets") {
             val dataset = dsOf(Duration.ZERO)
             dataset.collectAsList() shouldBe listOf(Duration.ZERO)
         }
+
+        xshould("handle Kotlin Duration Datasets") {
+            val dataset = dsOf(Duration.ZERO.toKotlinDuration())
+            dataset.collectAsList() shouldBe listOf(Duration.ZERO.toKotlinDuration())
+        }
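+        // ^ Disabled: kotlin.time.Duration does not round-trip yet, likely because
+        // it is a value class (see the TODO in DurationUdt).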
         //#endif
 
         //#if sparkMinor >= 3.2
@@ -92,6 +135,33 @@ class EncodingTest : ShouldSpec({
         }
         //#endif
 
+        should("handle Kotlinx DateTimePeriod Datasets") {
+            val periods = listOf(DateTimePeriod(years = 1), DateTimePeriod(hours = 2))
+            val dataset = periods.toDS()
+
+            dataset.show(false)
+
+            dataset.collectAsList().let {
+                it[0] shouldBe DateTimePeriod(years = 1)
+                // NOTE years fold into months and the time part into microseconds; both normalize back on read.
+                it[1] shouldBe DateTimePeriod(hours = 2)
+            }
+        }
+
+        should("handle Kotlinx DatePeriod Datasets") {
+            val periods = listOf(Period.ZERO.toKotlinDatePeriod(), Period.ofDays(2).toKotlinDatePeriod())
+            val dataset = periods.toDS()
+
+            dataset.show(false)
+
+            dataset.collectAsList().let {
+                it[0] shouldBe Period.ZERO.toKotlinDatePeriod()
+
+                // NOTE Spark truncates the period to whole months, so the days component is lost.
+                it[1] shouldBe Period.ofDays(0).toKotlinDatePeriod()
+            }
+        }
+
         should("handle binary datasets") {
             val byteArray = "Hello there".encodeToByteArray()
             val dataset = dsOf(byteArray)