@@ -3,7 +3,16 @@ package org.jetbrains.kotlinx.dataframe.io
33import io.kotest.assertions.throwables.shouldThrow
44import io.kotest.matchers.collections.shouldContain
55import io.kotest.matchers.shouldBe
6+ import org.apache.arrow.memory.RootAllocator
7+ import org.apache.arrow.vector.TimeStampMicroVector
8+ import org.apache.arrow.vector.TimeStampMilliVector
9+ import org.apache.arrow.vector.TimeStampNanoVector
10+ import org.apache.arrow.vector.TimeStampSecVector
11+ import org.apache.arrow.vector.VectorSchemaRoot
12+ import org.apache.arrow.vector.ipc.ArrowFileWriter
13+ import org.apache.arrow.vector.ipc.ArrowStreamWriter
614import org.apache.arrow.vector.types.FloatingPointPrecision
15+ import org.apache.arrow.vector.types.TimeUnit
716import org.apache.arrow.vector.types.pojo.ArrowType
817import org.apache.arrow.vector.types.pojo.Field
918import org.apache.arrow.vector.types.pojo.FieldType
@@ -23,10 +32,13 @@ import org.jetbrains.kotlinx.dataframe.api.remove
2332import org.jetbrains.kotlinx.dataframe.api.toColumn
2433import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
2534import org.junit.Test
35+ import java.io.ByteArrayOutputStream
2636import java.io.File
2737import java.net.URL
38+ import java.nio.channels.Channels
2839import java.time.LocalDate
2940import java.time.LocalDateTime
41+ import java.time.ZoneOffset
3042import java.util.Locale
3143import kotlin.reflect.typeOf
3244
@@ -459,4 +471,86 @@ internal class ArrowKtTest {
459471 val data = dataFrame.saveArrowFeatherToByteArray()
460472 DataFrame .readArrowFeather(data) shouldBe dataFrame
461473 }
474+
/**
 * Round-trips the same date-times through Arrow timestamp columns of four precisions
 * (nano, micro, milli, second) and checks that reading them back, from both the Arrow
 * file format and the Arrow streaming format, yields a frame equal to one built directly
 * from the original [LocalDateTime] values.
 */
@Test
fun testTimeStamp() {
    val sampleDates = listOf(
        LocalDateTime.of(2023, 11, 23, 9, 30, 25),
        LocalDateTime.of(2015, 5, 25, 14, 20, 13),
        LocalDateTime.of(2013, 6, 19, 11, 20, 13)
    )

    // Every column is expected to hold the very same values regardless of stored precision.
    val expected = dataFrameOf(
        " ts_nano" to sampleDates,
        " ts_micro" to sampleDates,
        " ts_milli" to sampleDates,
        " ts_sec" to sampleDates
    )

    // Arrow file (Feather) format.
    val featherBytes = writeArrowTimestamp(sampleDates)
    DataFrame.readArrowFeather(featherBytes) shouldBe expected

    // Arrow streaming (IPC) format.
    val ipcBytes = writeArrowTimestamp(sampleDates, streaming = true)
    DataFrame.readArrowIPC(ipcBytes) shouldBe expected
}
493+
/**
 * Serializes [dates] into an in-memory Arrow payload with four nullable, timezone-less
 * timestamp columns (nanosecond, microsecond, millisecond and second precision), each
 * holding the same instants.
 *
 * Every [LocalDateTime] is converted to an instant at [ZoneOffset.UTC]; the per-unit epoch
 * values are derived from the instant's epoch second and nanosecond-of-second so that
 * sub-second components are scaled correctly for each precision.
 *
 * @param dates values written to every column, in list order.
 * @param streaming when `true` the payload is produced with [ArrowStreamWriter], otherwise
 *   with [ArrowFileWriter].
 * @return the bytes emitted by the selected writer for a single record batch.
 */
private fun writeArrowTimestamp(dates: List<LocalDateTime>, streaming: Boolean = false): ByteArray {
    // All four columns share the same shape: nullable timestamp, no timezone, no children.
    fun timestampField(name: String, unit: TimeUnit): Field =
        Field(name, FieldType.nullable(ArrowType.Timestamp(unit, null)), null)

    val schemaTimeStamp = Schema(
        listOf(
            timestampField(" ts_nano", TimeUnit.NANOSECOND),
            timestampField(" ts_micro", TimeUnit.MICROSECOND),
            timestampField(" ts_milli", TimeUnit.MILLISECOND),
            timestampField(" ts_sec", TimeUnit.SECOND)
        )
    )

    return RootAllocator().use { allocator ->
        VectorSchemaRoot.create(schemaTimeStamp, allocator).use { vectorSchemaRoot ->
            val nanoVector = vectorSchemaRoot.getVector(" ts_nano") as TimeStampNanoVector
            val microVector = vectorSchemaRoot.getVector(" ts_micro") as TimeStampMicroVector
            val milliVector = vectorSchemaRoot.getVector(" ts_milli") as TimeStampMilliVector
            val secVector = vectorSchemaRoot.getVector(" ts_sec") as TimeStampSecVector

            val rowCount = dates.size
            nanoVector.allocateNew(rowCount)
            microVector.allocateNew(rowCount)
            milliVector.allocateNew(rowCount)
            secVector.allocateNew(rowCount)

            dates.forEachIndexed { index, localDateTime ->
                val instant = localDateTime.toInstant(ZoneOffset.UTC)
                val epochSecond = instant.epochSecond
                // instant.nano is the whole nanosecond-of-second (it already includes the
                // millisecond part), so it must be added to whole seconds, not to epoch millis.
                nanoVector[index] = epochSecond * 1_000_000_000L + instant.nano
                microVector[index] = epochSecond * 1_000_000L + instant.nano / 1_000
                milliVector[index] = instant.toEpochMilli()
                secVector[index] = epochSecond
            }
            vectorSchemaRoot.setRowCount(rowCount)

            val bos = ByteArrayOutputStream()
            bos.use { out ->
                val channel = Channels.newChannel(out)
                val arrowWriter = if (streaming) {
                    ArrowStreamWriter(vectorSchemaRoot, null, channel)
                } else {
                    ArrowFileWriter(vectorSchemaRoot, null, channel)
                }
                // Closing the writer finalizes the payload before the bytes are read.
                arrowWriter.use { writer ->
                    writer.start()
                    writer.writeBatch()
                }
            }
            bos.toByteArray()
        }
    }
}
462556}
0 commit comments