Skip to content

Add read of Arrow TimeStamp without timezone as LocalDatetime #515 #516

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import org.apache.arrow.vector.TimeMicroVector
import org.apache.arrow.vector.TimeMilliVector
import org.apache.arrow.vector.TimeNanoVector
import org.apache.arrow.vector.TimeSecVector
import org.apache.arrow.vector.TimeStampMicroVector
import org.apache.arrow.vector.TimeStampMilliVector
import org.apache.arrow.vector.TimeStampNanoVector
import org.apache.arrow.vector.TimeStampSecVector
import org.apache.arrow.vector.TinyIntVector
import org.apache.arrow.vector.UInt1Vector
import org.apache.arrow.vector.UInt2Vector
Expand Down Expand Up @@ -130,6 +134,39 @@ private fun TimeMilliVector.values(range: IntRange): List<LocalTime?> = range.ma

private fun TimeSecVector.values(range: IntRange): List<LocalTime?> =
range.map { getObject(it)?.let { LocalTime.ofSecondOfDay(it.toLong()) } }

private fun TimeStampNanoVector.values(range: IntRange): List<LocalDateTime?> = range.mapIndexed { i, it ->
if (isNull(i)) {
null
} else {
getObject(it)
}
}

private fun TimeStampMicroVector.values(range: IntRange): List<LocalDateTime?> = range.mapIndexed { i, it ->
if (isNull(i)) {
null
} else {
getObject(it)
}
}

private fun TimeStampMilliVector.values(range: IntRange): List<LocalDateTime?> = range.mapIndexed { i, it ->
if (isNull(i)) {
null
} else {
getObject(it)
}
}

private fun TimeStampSecVector.values(range: IntRange): List<LocalDateTime?> = range.mapIndexed { i, it ->
if (isNull(i)) {
null
} else {
getObject(it)
}
}

private fun StructVector.values(range: IntRange): List<Map<String, Any?>?> = range.map {
getObject(it)
}
Expand Down Expand Up @@ -202,6 +239,10 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: Nullabi
is TimeMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
is TimeMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
is TimeSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
is TimeStampNanoVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
is TimeStampMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
is TimeStampMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
is TimeStampSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
else -> {
throw NotImplementedError("reading from ${vector.javaClass.canonicalName} is not implemented")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@ package org.jetbrains.kotlinx.dataframe.io
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.collections.shouldContain
import io.kotest.matchers.shouldBe
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.TimeStampMicroVector
import org.apache.arrow.vector.TimeStampMilliVector
import org.apache.arrow.vector.TimeStampNanoVector
import org.apache.arrow.vector.TimeStampSecVector
import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.ipc.ArrowFileWriter
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.arrow.vector.types.FloatingPointPrecision
import org.apache.arrow.vector.types.TimeUnit
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.arrow.vector.types.pojo.Field
import org.apache.arrow.vector.types.pojo.FieldType
Expand All @@ -23,10 +32,13 @@ import org.jetbrains.kotlinx.dataframe.api.remove
import org.jetbrains.kotlinx.dataframe.api.toColumn
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
import org.junit.Test
import java.io.ByteArrayOutputStream
import java.io.File
import java.net.URL
import java.nio.channels.Channels
import java.time.LocalDate
import java.time.LocalDateTime
import java.time.ZoneOffset
import java.util.Locale
import kotlin.reflect.typeOf

Expand Down Expand Up @@ -459,4 +471,86 @@ internal class ArrowKtTest {
val data = dataFrame.saveArrowFeatherToByteArray()
DataFrame.readArrowFeather(data) shouldBe dataFrame
}

@Test
fun testTimeStamp(){
val dates = listOf(
LocalDateTime.of(2023, 11, 23, 9, 30, 25),
LocalDateTime.of(2015, 5, 25, 14, 20, 13),
LocalDateTime.of(2013, 6, 19, 11, 20, 13)
)

val dataFrame = dataFrameOf(
"ts_nano" to dates,
"ts_micro" to dates,
"ts_milli" to dates,
"ts_sec" to dates
)

DataFrame.readArrowFeather(writeArrowTimestamp(dates)) shouldBe dataFrame
DataFrame.readArrowIPC(writeArrowTimestamp(dates, true)) shouldBe dataFrame
}

private fun writeArrowTimestamp(dates: List<LocalDateTime>, streaming: Boolean = false): ByteArray {
RootAllocator().use { allocator ->
val timeStampMilli = Field(
"ts_milli",
FieldType.nullable(ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
null
)

val timeStampMicro = Field(
"ts_micro",
FieldType.nullable(ArrowType.Timestamp(TimeUnit.MICROSECOND, null)),
null
)

val timeStampNano = Field(
"ts_nano",
FieldType.nullable(ArrowType.Timestamp(TimeUnit.NANOSECOND, null)),
null
)

val timeStampSec = Field(
"ts_sec",
FieldType.nullable(ArrowType.Timestamp(TimeUnit.SECOND, null)),
null
)
val schemaTimeStamp = Schema(
listOf(timeStampNano, timeStampMicro, timeStampMilli, timeStampSec)
)
VectorSchemaRoot.create(schemaTimeStamp, allocator).use { vectorSchemaRoot ->
val timeStampMilliVector = vectorSchemaRoot.getVector("ts_milli") as TimeStampMilliVector
val timeStampNanoVector = vectorSchemaRoot.getVector("ts_nano") as TimeStampNanoVector
val timeStampMicroVector = vectorSchemaRoot.getVector("ts_micro") as TimeStampMicroVector
val timeStampSecVector = vectorSchemaRoot.getVector("ts_sec") as TimeStampSecVector
timeStampMilliVector.allocateNew(dates.size)
timeStampNanoVector.allocateNew(dates.size)
timeStampMicroVector.allocateNew(dates.size)
timeStampSecVector.allocateNew(dates.size)

dates.forEachIndexed { index, localDateTime ->
val instant = localDateTime.toInstant(ZoneOffset.UTC)
timeStampNanoVector[index] = instant.toEpochMilli() * 1_000_000L + instant.nano
timeStampMicroVector[index] = instant.toEpochMilli() * 1_000L
timeStampMilliVector[index] = instant.toEpochMilli()
timeStampSecVector[index] = instant.toEpochMilli() / 1_000L
}
vectorSchemaRoot.setRowCount(dates.size)
val bos = ByteArrayOutputStream()
bos.use { out ->
val arrowWriter = if (streaming) {
ArrowStreamWriter(vectorSchemaRoot, null, Channels.newChannel(out))
} else {
ArrowFileWriter(vectorSchemaRoot, null, Channels.newChannel(out))
}
arrowWriter.use { writer ->
writer.start()
writer.writeBatch()
}
}
return bos.toByteArray()
}
}
}
}