Skip to content

Commit 7e44783

Browse files
committed
Add read of Arrow TimeStamp without timezone as LocalDateTime #515
1 parent b3f3331 commit 7e44783

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed

dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrowReadingImpl.kt

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ import org.apache.arrow.vector.TimeMicroVector
1818
import org.apache.arrow.vector.TimeMilliVector
1919
import org.apache.arrow.vector.TimeNanoVector
2020
import org.apache.arrow.vector.TimeSecVector
21+
import org.apache.arrow.vector.TimeStampMicroVector
22+
import org.apache.arrow.vector.TimeStampMilliVector
23+
import org.apache.arrow.vector.TimeStampNanoVector
24+
import org.apache.arrow.vector.TimeStampSecVector
2125
import org.apache.arrow.vector.TinyIntVector
2226
import org.apache.arrow.vector.UInt1Vector
2327
import org.apache.arrow.vector.UInt2Vector
@@ -130,6 +134,39 @@ private fun TimeMilliVector.values(range: IntRange): List<LocalTime?> = range.ma
130134

131135
/**
 * Reads the entries at the vector positions in [range] and converts each one to a [LocalTime]
 * via [LocalTime.ofSecondOfDay]. Entries for which `getObject` yields `null` stay `null`.
 */
private fun TimeSecVector.values(range: IntRange): List<LocalTime?> =
    range.map { index ->
        val secondOfDay = getObject(index)
        if (secondOfDay == null) null else LocalTime.ofSecondOfDay(secondOfDay.toLong())
    }
137+
138+
/**
 * Reads the entries at the vector positions in [range] as [LocalDateTime] values.
 *
 * @param range positions inside this vector to read; it does not have to start at 0.
 * @return one element per position in [range], `null` where the vector slot is null.
 */
private fun TimeStampNanoVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { index ->
        // The null check must use the vector position itself, not the element's ordinal
        // within `range`; the two only coincide when the range starts at 0.
        if (isNull(index)) null else getObject(index)
    }
145+
146+
/**
 * Reads the entries at the vector positions in [range] as [LocalDateTime] values.
 *
 * @param range positions inside this vector to read; it does not have to start at 0.
 * @return one element per position in [range], `null` where the vector slot is null.
 */
private fun TimeStampMicroVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { index ->
        // The null check must use the vector position itself, not the element's ordinal
        // within `range`; the two only coincide when the range starts at 0.
        if (isNull(index)) null else getObject(index)
    }
153+
154+
/**
 * Reads the entries at the vector positions in [range] as [LocalDateTime] values.
 *
 * @param range positions inside this vector to read; it does not have to start at 0.
 * @return one element per position in [range], `null` where the vector slot is null.
 */
private fun TimeStampMilliVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { index ->
        // The null check must use the vector position itself, not the element's ordinal
        // within `range`; the two only coincide when the range starts at 0.
        if (isNull(index)) null else getObject(index)
    }
161+
162+
/**
 * Reads the entries at the vector positions in [range] as [LocalDateTime] values.
 *
 * @param range positions inside this vector to read; it does not have to start at 0.
 * @return one element per position in [range], `null` where the vector slot is null.
 */
private fun TimeStampSecVector.values(range: IntRange): List<LocalDateTime?> =
    range.map { index ->
        // The null check must use the vector position itself, not the element's ordinal
        // within `range`; the two only coincide when the range starts at 0.
        if (isNull(index)) null else getObject(index)
    }
169+
133170
/**
 * Collects the struct entries at the vector positions in [range], each as whatever
 * `getObject` returns for that position (a field-name-to-value map, or `null`).
 */
private fun StructVector.values(range: IntRange): List<Map<String, Any?>?> =
    range.map { index -> getObject(index) }
@@ -202,6 +239,10 @@ private fun readField(root: VectorSchemaRoot, field: Field, nullability: Nullabi
202239
is TimeMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
203240
is TimeMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
204241
is TimeSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
242+
is TimeStampNanoVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
243+
is TimeStampMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
244+
is TimeStampMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
245+
is TimeStampSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
205246
is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability)
206247
else -> {
207248
throw NotImplementedError("reading from ${vector.javaClass.canonicalName} is not implemented")

dataframe-arrow/src/test/kotlin/org/jetbrains/kotlinx/dataframe/io/ArrowKtTest.kt

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@ package org.jetbrains.kotlinx.dataframe.io
33
import io.kotest.assertions.throwables.shouldThrow
44
import io.kotest.matchers.collections.shouldContain
55
import io.kotest.matchers.shouldBe
6+
import org.apache.arrow.memory.RootAllocator
7+
import org.apache.arrow.vector.TimeStampMicroVector
8+
import org.apache.arrow.vector.TimeStampMilliVector
9+
import org.apache.arrow.vector.TimeStampNanoVector
10+
import org.apache.arrow.vector.TimeStampSecVector
11+
import org.apache.arrow.vector.VectorSchemaRoot
12+
import org.apache.arrow.vector.ipc.ArrowFileWriter
13+
import org.apache.arrow.vector.ipc.ArrowStreamWriter
614
import org.apache.arrow.vector.types.FloatingPointPrecision
15+
import org.apache.arrow.vector.types.TimeUnit
716
import org.apache.arrow.vector.types.pojo.ArrowType
817
import org.apache.arrow.vector.types.pojo.Field
918
import org.apache.arrow.vector.types.pojo.FieldType
@@ -23,10 +32,13 @@ import org.jetbrains.kotlinx.dataframe.api.remove
2332
import org.jetbrains.kotlinx.dataframe.api.toColumn
2433
import org.jetbrains.kotlinx.dataframe.exceptions.TypeConverterNotFoundException
2534
import org.junit.Test
35+
import java.io.ByteArrayOutputStream
2636
import java.io.File
2737
import java.net.URL
38+
import java.nio.channels.Channels
2839
import java.time.LocalDate
2940
import java.time.LocalDateTime
41+
import java.time.ZoneOffset
3042
import java.util.Locale
3143
import kotlin.reflect.typeOf
3244

@@ -273,4 +285,87 @@ internal class ArrowKtTest {
273285
val data = dataFrame.saveArrowFeatherToByteArray()
274286
DataFrame.readArrowFeather(data) shouldBe dataFrame
275287
}
288+
289+
/**
 * Round-trips timezone-less Arrow timestamp columns of every precision written by
 * [writeArrowTimestamp] and checks that both the Feather (file) and the IPC (streaming)
 * readers return them as [LocalDateTime] columns.
 */
@Test
fun testTimeStamp() {
    val dates = listOf(
        LocalDateTime.of(2023, 11, 23, 9, 30, 25),
        LocalDateTime.of(2015, 5, 25, 14, 20, 13),
        LocalDateTime.of(2013, 6, 19, 11, 20, 13)
    )

    // Every column is expected to decode to the same date-times, whatever its time unit.
    val expected = dataFrameOf(
        "ts_nano" to dates,
        "ts_micro" to dates,
        "ts_milli" to dates,
        "ts_sec" to dates
    )

    val featherBytes = writeArrowTimestamp(dates)
    DataFrame.readArrowFeather(featherBytes) shouldBe expected

    val ipcBytes = writeArrowTimestamp(dates, true)
    DataFrame.readArrowIPC(ipcBytes) shouldBe expected
}
307+
308+
/**
 * Serializes [dates] into an in-memory Arrow payload holding four nullable timestamp columns
 * declared without a timezone: `ts_nano`, `ts_micro`, `ts_milli` and `ts_sec`.
 *
 * Each date-time is interpreted at [ZoneOffset.UTC] and stored in the unit of its column.
 *
 * @param dates values to write; row `i` of every column is derived from `dates[i]`.
 * @param streaming when `true` the payload is produced with [ArrowStreamWriter],
 *   otherwise with [ArrowFileWriter].
 * @return the bytes produced by the chosen writer.
 */
private fun writeArrowTimestamp(dates: List<LocalDateTime>, streaming: Boolean = false): ByteArray {
    // All four fields differ only in name and time unit.
    fun timestampField(name: String, unit: TimeUnit) =
        Field(name, FieldType.nullable(ArrowType.Timestamp(unit, null)), null)

    val schemaTimeStamp = Schema(
        listOf(
            timestampField("ts_nano", TimeUnit.NANOSECOND),
            timestampField("ts_micro", TimeUnit.MICROSECOND),
            timestampField("ts_milli", TimeUnit.MILLISECOND),
            timestampField("ts_sec", TimeUnit.SECOND)
        )
    )

    return RootAllocator().use { allocator ->
        VectorSchemaRoot.create(schemaTimeStamp, allocator).use { vectorSchemaRoot ->
            val timeStampNanoVector = vectorSchemaRoot.getVector("ts_nano") as TimeStampNanoVector
            val timeStampMicroVector = vectorSchemaRoot.getVector("ts_micro") as TimeStampMicroVector
            val timeStampMilliVector = vectorSchemaRoot.getVector("ts_milli") as TimeStampMilliVector
            val timeStampSecVector = vectorSchemaRoot.getVector("ts_sec") as TimeStampSecVector
            timeStampNanoVector.allocateNew(dates.size)
            timeStampMicroVector.allocateNew(dates.size)
            timeStampMilliVector.allocateNew(dates.size)
            timeStampSecVector.allocateNew(dates.size)

            dates.forEachIndexed { index, localDateTime ->
                val instant = localDateTime.toInstant(ZoneOffset.UTC)
                // Build every unit from (epochSecond, nano-of-second): `instant.nano` already
                // contains the millisecond part, so it must not be added on top of epoch millis.
                val epochSecond = instant.epochSecond
                val nanoOfSecond = instant.nano
                timeStampNanoVector[index] = epochSecond * 1_000_000_000L + nanoOfSecond
                timeStampMicroVector[index] = epochSecond * 1_000_000L + nanoOfSecond / 1_000
                timeStampMilliVector[index] = instant.toEpochMilli()
                timeStampSecVector[index] = epochSecond
            }
            vectorSchemaRoot.setRowCount(dates.size)

            val bos = ByteArrayOutputStream()
            bos.use { out ->
                val arrowWriter = if (streaming) {
                    ArrowStreamWriter(vectorSchemaRoot, null, Channels.newChannel(out))
                } else {
                    ArrowFileWriter(vectorSchemaRoot, null, Channels.newChannel(out))
                }
                arrowWriter.use { writer ->
                    writer.start()
                    writer.writeBatch()
                }
            }
            // Read the buffer only after the writer has been closed.
            bos.toByteArray()
        }
    }
}
370+
}
276371
}

0 commit comments

Comments
 (0)