From e36b006f82985e8b2f6d9567d2b883691a086a77 Mon Sep 17 00:00:00 2001 From: Florian Bernard Date: Fri, 24 Nov 2023 15:45:18 +0100 Subject: [PATCH] Add read of Arrow TimeStamp without timezone as LocalDatetime #515 --- .../kotlinx/dataframe/io/arrowReadingImpl.kt | 41 ++++++++ .../kotlinx/dataframe/io/ArrowKtTest.kt | 94 +++++++++++++++++++ 2 files changed, 135 insertions(+) diff --git a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt index a9aacee7e5..5c09295179 100644 --- a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt +++ b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt @@ -18,6 +18,10 @@ import org.apache.arrow.vector.TimeMicroVector import org.apache.arrow.vector.TimeMilliVector import org.apache.arrow.vector.TimeNanoVector import org.apache.arrow.vector.TimeSecVector +import org.apache.arrow.vector.TimeStampMicroVector +import org.apache.arrow.vector.TimeStampMilliVector +import org.apache.arrow.vector.TimeStampNanoVector +import org.apache.arrow.vector.TimeStampSecVector import org.apache.arrow.vector.TinyIntVector import org.apache.arrow.vector.UInt1Vector import org.apache.arrow.vector.UInt2Vector @@ -130,6 +134,39 @@ private fun TimeMilliVector.values(range: IntRange): List = range.ma private fun TimeSecVector.values(range: IntRange): List = range.map { getObject(it)?.let { LocalTime.ofSecondOfDay(it.toLong()) } } + +private fun TimeStampNanoVector.values(range: IntRange): List = range.mapIndexed { i, it -> + if (isNull(i)) { + null + } else { + getObject(it) + } +} + +private fun TimeStampMicroVector.values(range: IntRange): List = range.mapIndexed { i, it -> + if (isNull(i)) { + null + } else { + getObject(it) + } +} + +private fun TimeStampMilliVector.values(range: IntRange): List = range.mapIndexed { i, it -> + if (isNull(i)) { + null + } else { + getObject(it) + } +} + +private fun TimeStampSecVector.values(range: IntRange): List = range.mapIndexed { i, it -> + if (isNull(i)) { + null + } else { + getObject(it) + } +} + private fun StructVector.values(range: IntRange): List?> = range.map { getObject(it) } @@ -202,6 +239,10 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: Nullabi is TimeMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) is TimeMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) is TimeSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is TimeStampNanoVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is TimeStampMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is TimeStampMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is TimeStampSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) else -> { throw NotImplementedError("reading from ${vector.javaClass.canonicalName} is not implemented") diff --git a/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt b/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt index 9aeaf5a7ce..500780ec54 100644 --- a/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt +++ b/dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt @@ -3,7 +3,16 @@ package org.jetbrains.kotlinx.dataframe.io import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.collections.shouldContain import io.kotest.matchers.shouldBe +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.TimeStampMicroVector +import org.apache.arrow.vector.TimeStampMilliVector +import org.apache.arrow.vector.TimeStampNanoVector +import org.apache.arrow.vector.TimeStampSecVector +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowFileWriter +import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.TimeUnit import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.arrow.vector.types.pojo.Field import org.apache.arrow.vector.types.pojo.FieldType @@ -23,10 +32,13 @@ import org.jetbrains.kotlinx.dataframe.api.remove import org.jetbrains.kotlinx.dataframe.api.toColumn import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException import org.junit.Test +import java.io.ByteArrayOutputStream import java.io.File import java.net.URL +import java.nio.channels.Channels import java.time.LocalDate import java.time.LocalDateTime +import java.time.ZoneOffset import java.util.Locale import kotlin.reflect.typeOf @@ -459,4 +471,86 @@ internal class ArrowKtTest { val data = dataFrame.saveArrowFeatherToByteArray() DataFrame.readArrowFeather(data) shouldBe dataFrame } + + @Test + fun testTimeStamp(){ + val dates = listOf( + LocalDateTime.of(2023, 11, 23, 9, 30, 25), + LocalDateTime.of(2015, 5, 25, 14, 20, 13), + LocalDateTime.of(2013, 6, 19, 11, 20, 13) + ) + + val dataFrame = dataFrameOf( + "ts_nano" to dates, + "ts_micro" to dates, + "ts_milli" to dates, + "ts_sec" to dates + ) + + DataFrame.readArrowFeather(writeArrowTimestamp(dates)) shouldBe dataFrame + DataFrame.readArrowIPC(writeArrowTimestamp(dates, true)) shouldBe dataFrame + } + + private fun writeArrowTimestamp(dates: List, streaming: Boolean = false): ByteArray { + RootAllocator().use { allocator -> + val timeStampMilli = Field( + "ts_milli", + FieldType.nullable(ArrowType.Timestamp(TimeUnit.MILLISECOND, null)), + null + ) + + val timeStampMicro = Field( + "ts_micro", + FieldType.nullable(ArrowType.Timestamp(TimeUnit.MICROSECOND, null)), + null + ) + + val timeStampNano = Field( + "ts_nano", + FieldType.nullable(ArrowType.Timestamp(TimeUnit.NANOSECOND, null)), + null + ) + + val timeStampSec = Field( + "ts_sec", + FieldType.nullable(ArrowType.Timestamp(TimeUnit.SECOND, null)), + null + ) + val schemaTimeStamp = Schema( + listOf(timeStampNano, timeStampMicro, timeStampMilli, timeStampSec) + ) + VectorSchemaRoot.create(schemaTimeStamp, allocator).use { vectorSchemaRoot -> + val timeStampMilliVector = vectorSchemaRoot.getVector("ts_milli") as TimeStampMilliVector + val timeStampNanoVector = vectorSchemaRoot.getVector("ts_nano") as TimeStampNanoVector + val timeStampMicroVector = vectorSchemaRoot.getVector("ts_micro") as TimeStampMicroVector + val timeStampSecVector = vectorSchemaRoot.getVector("ts_sec") as TimeStampSecVector + timeStampMilliVector.allocateNew(dates.size) + timeStampNanoVector.allocateNew(dates.size) + timeStampMicroVector.allocateNew(dates.size) + timeStampSecVector.allocateNew(dates.size) + + dates.forEachIndexed { index, localDateTime -> + val instant = localDateTime.toInstant(ZoneOffset.UTC) + timeStampNanoVector[index] = instant.toEpochMilli() * 1_000_000L + instant.nano + timeStampMicroVector[index] = instant.toEpochMilli() * 1_000L + timeStampMilliVector[index] = instant.toEpochMilli() + timeStampSecVector[index] = instant.toEpochMilli() / 1_000L + } + vectorSchemaRoot.setRowCount(dates.size) + val bos = ByteArrayOutputStream() + bos.use { out -> + val arrowWriter = if (streaming) { + ArrowStreamWriter(vectorSchemaRoot, null, Channels.newChannel(out)) + } else { + ArrowFileWriter(vectorSchemaRoot, null, Channels.newChannel(out)) + } + arrowWriter.use { writer -> + writer.start() + writer.writeBatch() + } + } + return bos.toByteArray() + } + } + } }