diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/TypeConversions.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/TypeConversions.kt index 9bebaa60b8..1376043d1b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/TypeConversions.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/TypeConversions.kt @@ -17,6 +17,7 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnSet import org.jetbrains.kotlinx.dataframe.columns.FrameColumn import org.jetbrains.kotlinx.dataframe.columns.ValueColumn import org.jetbrains.kotlinx.dataframe.impl.GroupByImpl +import org.jetbrains.kotlinx.dataframe.impl.anyNull import org.jetbrains.kotlinx.dataframe.impl.asList import org.jetbrains.kotlinx.dataframe.impl.columnName import org.jetbrains.kotlinx.dataframe.impl.columns.ColumnAccessorImpl @@ -189,6 +190,50 @@ public enum class Infer { Type } +/** + * Indicates how [DataColumn.hasNulls] (or, more accurately, DataColumn.type.isMarkedNullable) should be initialized from + * expected schema and actual data when reading schema-defined data formats. + */ +public enum class NullabilityOptions { + /** + * Use only actual data, set [DataColumn.hasNulls] to true if and only if there are null values in the column. + * On empty dataset use False. + */ + Infer, + + /** + * Set [DataColumn.hasNulls] to expected value. Throw exception if column should be not nullable but there are null values. + */ + Checking, + + /** + * Set [DataColumn.hasNulls] to expected value by default. Change False to True if column should be not nullable but there are null values. + */ + Widening +} + +public class NullabilityException() : Exception() + +/** + * @return if column should be marked nullable for current [NullabilityOptions] value with actual [data] and [expectedNulls] per some schema/signature. + * @throws [NullabilityException] for [NullabilityOptions.Checking] if [expectedNulls] is false and [data] contains nulls. + */ +public fun NullabilityOptions.applyNullability(data: List, expectedNulls: Boolean): Boolean { + val hasNulls = data.anyNull() + return when (this) { + NullabilityOptions.Infer -> hasNulls + NullabilityOptions.Checking -> { + if (!expectedNulls && hasNulls) { + throw NullabilityException() + } + expectedNulls + } + NullabilityOptions.Widening -> { + expectedNulls || hasNulls + } + } +} + public inline fun Iterable.toColumn( name: String = "", infer: Infer = Infer.None diff --git a/dataframe-arrow/build.gradle.kts b/dataframe-arrow/build.gradle.kts index 07a135589d..6f23654bf5 100644 --- a/dataframe-arrow/build.gradle.kts +++ b/dataframe-arrow/build.gradle.kts @@ -12,6 +12,7 @@ dependencies { implementation(libs.arrow.format) implementation(libs.arrow.memory) implementation(libs.commonsCompress) + implementation(libs.kotlin.reflect) testApi(project(":core")) testImplementation(libs.junit) diff --git a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt index 86c8e2e2e0..6394e3bce3 100644 --- a/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt +++ b/dataframe-arrow/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/arrow.kt @@ -3,6 +3,8 @@ package org.jetbrains.kotlinx.dataframe.io import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.BigIntVector import org.apache.arrow.vector.BitVector +import org.apache.arrow.vector.DateDayVector +import org.apache.arrow.vector.DateMilliVector import org.apache.arrow.vector.Decimal256Vector import org.apache.arrow.vector.DecimalVector import org.apache.arrow.vector.DurationVector @@ -28,16 +30,24 @@ import org.apache.arrow.vector.complex.StructVector import org.apache.arrow.vector.ipc.ArrowFileReader import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.arrow.vector.types.pojo.Field +import org.apache.arrow.vector.util.DateUtility import org.apache.commons.compress.utils.SeekableInMemoryByteChannel import org.jetbrains.kotlinx.dataframe.AnyBaseCol import org.jetbrains.kotlinx.dataframe.AnyFrame import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.DataFrame import org.jetbrains.kotlinx.dataframe.api.Infer -import org.jetbrains.kotlinx.dataframe.api.concat +import org.jetbrains.kotlinx.dataframe.api.NullabilityOptions +import org.jetbrains.kotlinx.dataframe.api.applyNullability +import org.jetbrains.kotlinx.dataframe.api.NullabilityException +import org.jetbrains.kotlinx.dataframe.api.cast +import org.jetbrains.kotlinx.dataframe.api.dataFrameOf +import org.jetbrains.kotlinx.dataframe.api.emptyDataFrame +import org.jetbrains.kotlinx.dataframe.api.getColumn import org.jetbrains.kotlinx.dataframe.api.toDataFrame import org.jetbrains.kotlinx.dataframe.codeGen.AbstractDefaultReadMethod import org.jetbrains.kotlinx.dataframe.codeGen.DefaultReadDfMethod +import org.jetbrains.kotlinx.dataframe.impl.asList import java.io.File import java.io.InputStream import java.math.BigDecimal @@ -48,13 +58,17 @@ import java.nio.channels.ReadableByteChannel import java.nio.channels.SeekableByteChannel import java.nio.file.Files import java.time.Duration +import java.time.LocalDate import java.time.LocalDateTime +import java.time.LocalTime +import kotlin.reflect.KType +import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf public class ArrowFeather : SupportedFormat { - override fun readDataFrame(stream: InputStream, header: List): AnyFrame = DataFrame.readArrowFeather(stream) + override fun readDataFrame(stream: InputStream, header: List): AnyFrame = DataFrame.readArrowFeather(stream, NullabilityOptions.Widening) - override fun readDataFrame(file: File, header: List): AnyFrame = DataFrame.readArrowFeather(file) + override fun readDataFrame(file: File, header: List): AnyFrame = DataFrame.readArrowFeather(file, NullabilityOptions.Widening) override fun acceptsExtension(ext: String): Boolean = ext == "feather" @@ -75,45 +89,64 @@ internal object Allocator { } } +/** + * same as [Iterable>.concat()] without internal type guessing (all batches should have the same schema) + */ +internal fun Iterable>.concatKeepingSchema(): DataFrame { + val dataFrames = asList() + when (dataFrames.size) { + 0 -> return emptyDataFrame() + 1 -> return dataFrames[0] + } + + val columnNames = dataFrames.first().columnNames() + + val columns = columnNames.map { name -> + val values = dataFrames.flatMap { it.getColumn(name).values() } + DataColumn.createValueColumn(name, values, dataFrames.first().getColumn(name).type()) + } + return dataFrameOf(columns).cast() +} + /** * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [channel] */ -public fun DataFrame.Companion.readArrowIPC(channel: ReadableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame { +public fun DataFrame.Companion.readArrowIPC(channel: ReadableByteChannel, allocator: RootAllocator = Allocator.ROOT, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame { ArrowStreamReader(channel, allocator).use { reader -> val dfs = buildList { val root = reader.vectorSchemaRoot val schema = root.schema while (reader.loadNextBatch()) { - val df = schema.fields.map { f -> readField(root, f) }.toDataFrame() + val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame() add(df) } } - return dfs.concat() + return dfs.concatKeepingSchema() } } /** * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [channel] */ -public fun DataFrame.Companion.readArrowFeather(channel: SeekableByteChannel, allocator: RootAllocator = Allocator.ROOT): AnyFrame { +public fun DataFrame.Companion.readArrowFeather(channel: SeekableByteChannel, allocator: RootAllocator = Allocator.ROOT, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame { ArrowFileReader(channel, allocator).use { reader -> val dfs = buildList { reader.recordBlocks.forEach { block -> reader.loadRecordBatch(block) val root = reader.vectorSchemaRoot val schema = root.schema - val df = schema.fields.map { f -> readField(root, f) }.toDataFrame() + val df = schema.fields.map { f -> readField(root, f, nullability) }.toDataFrame() add(df) } } - return dfs.concat() + return dfs.concatKeepingSchema() } } private fun BitVector.values(range: IntRange): List = range.map { getObject(it) } -private fun UInt1Vector.values(range: IntRange): List = range.map { getObject(it) } -private fun UInt2Vector.values(range: IntRange): List = range.map { getObject(it) } +private fun UInt1Vector.values(range: IntRange): List = range.map { getObjectNoOverflow(it) } +private fun UInt2Vector.values(range: IntRange): List = range.map { getObject(it)?.code } private fun UInt4Vector.values(range: IntRange): List = range.map { getObjectNoOverflow(it) } private fun UInt8Vector.values(range: IntRange): List = range.map { getObjectNoOverflow(it) } @@ -129,10 +162,35 @@ private fun Float4Vector.values(range: IntRange): List = range.map { get private fun Float8Vector.values(range: IntRange): List = range.map { getObject(it) } private fun DurationVector.values(range: IntRange): List = range.map { getObject(it) } -private fun TimeNanoVector.values(range: IntRange): List = range.map { getObject(it) } -private fun TimeMicroVector.values(range: IntRange): List = range.map { getObject(it) } -private fun TimeMilliVector.values(range: IntRange): List = range.map { getObject(it) } -private fun TimeSecVector.values(range: IntRange): List = range.map { getObject(it) } +private fun DateDayVector.values(range: IntRange): List = range.map { + if (getObject(it) == null) null else + DateUtility.getLocalDateTimeFromEpochMilli(getObject(it).toLong() * DateUtility.daysToStandardMillis).toLocalDate() +} +private fun DateMilliVector.values(range: IntRange): List = range.map { getObject(it) } + +private fun TimeNanoVector.values(range: IntRange): List = range.mapIndexed { i, it -> + if (isNull(i)) { + null + } else { + LocalTime.ofNanoOfDay(get(it)) + } +} +private fun TimeMicroVector.values(range: IntRange): List = range.mapIndexed { i, it -> + if (isNull(i)) { + null + } else { + LocalTime.ofNanoOfDay(getObject(it) * 1000) + } +} +private fun TimeMilliVector.values(range: IntRange): List = range.mapIndexed { i, it -> + if (isNull(i)) { + null + } else { + LocalTime.ofNanoOfDay(get(it).toLong() * 1000_000) + } +} +private fun TimeSecVector.values(range: IntRange): List = range.map { getObject(it)?.let {LocalTime.ofSecondOfDay(it.toLong())} } + private fun StructVector.values(range: IntRange): List?> = range.map { getObject(it) } private fun VarCharVector.values(range: IntRange): List = range.map { @@ -167,39 +225,48 @@ private fun LargeVarCharVector.values(range: IntRange): List = range.ma } } -private inline fun List.withType() = this to typeOf() - -private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseCol { - val range = 0 until root.rowCount - val (list, type) = when (val vector = root.getVector(field)) { - is VarCharVector -> vector.values(range).withType() - is LargeVarCharVector -> vector.values(range).withType() - is VarBinaryVector -> vector.values(range).withType() - is LargeVarBinaryVector -> vector.values(range).withType() - is BitVector -> vector.values(range).withType() - is SmallIntVector -> vector.values(range).withType() - is TinyIntVector -> vector.values(range).withType() - is UInt1Vector -> vector.values(range).withType() - is UInt2Vector -> vector.values(range).withType() - is UInt4Vector -> vector.values(range).withType() - is UInt8Vector -> vector.values(range).withType() - is IntVector -> vector.values(range).withType() - is BigIntVector -> vector.values(range).withType() - is DecimalVector -> vector.values(range).withType() - is Decimal256Vector -> vector.values(range).withType() - is Float8Vector -> vector.values(range).withType() - is Float4Vector -> vector.values(range).withType() - is DurationVector -> vector.values(range).withType() - is TimeNanoVector -> vector.values(range).withType() - is TimeMicroVector -> vector.values(range).withType() - is TimeMilliVector -> vector.values(range).withType() - is TimeSecVector -> vector.values(range).withType() - is StructVector -> vector.values(range).withType() - else -> { - TODO("not fully implemented") +private inline fun List.withTypeNullable(expectedNulls: Boolean, nullabilityOptions: NullabilityOptions): Pair, KType> { + val nullable = nullabilityOptions.applyNullability(this, expectedNulls) + return this to typeOf().withNullability(nullable) +} + +private fun readField(root: VectorSchemaRoot, field: Field, nullability: NullabilityOptions): AnyBaseCol { + try { + val range = 0 until root.rowCount + val (list, type) = when (val vector = root.getVector(field)) { + is VarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is LargeVarCharVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is VarBinaryVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is LargeVarBinaryVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is BitVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is SmallIntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is TinyIntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is UInt1Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is UInt2Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is UInt4Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is UInt8Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is IntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is BigIntVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is DecimalVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is Decimal256Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is Float8Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is Float4Vector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is DurationVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is DateDayVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is DateMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is TimeNanoVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is TimeMicroVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is TimeMilliVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is TimeSecVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + is StructVector -> vector.values(range).withTypeNullable(field.isNullable, nullability) + else -> { + TODO("not fully implemented") + } } + return DataColumn.createValueColumn(field.name, list, type, Infer.None) + } catch (unexpectedNull: NullabilityException) { + throw IllegalArgumentException("Column `${field.name}` should be not nullable but has nulls") } - return DataColumn.createValueColumn(field.name, list, type, Infer.Nulls) } // IPC reading block @@ -207,34 +274,37 @@ private fun readField(root: VectorSchemaRoot, field: Field): AnyBaseCol { /** * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [file] */ -public fun DataFrame.Companion.readArrowIPC(file: File): AnyFrame = Files.newByteChannel(file.toPath()).use { readArrowIPC(it) } +public fun DataFrame.Companion.readArrowIPC(file: File, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = + Files.newByteChannel(file.toPath()).use { readArrowIPC(it, nullability = nullability) } /** * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [byteArray] */ -public fun DataFrame.Companion.readArrowIPC(byteArray: ByteArray): AnyFrame = SeekableInMemoryByteChannel(byteArray).use { readArrowIPC(it) } +public fun DataFrame.Companion.readArrowIPC(byteArray: ByteArray, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = + SeekableInMemoryByteChannel(byteArray).use { readArrowIPC(it, nullability = nullability) } /** * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [stream] */ -public fun DataFrame.Companion.readArrowIPC(stream: InputStream): AnyFrame = Channels.newChannel(stream).use { readArrowIPC(it) } +public fun DataFrame.Companion.readArrowIPC(stream: InputStream, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = + Channels.newChannel(stream).use { readArrowIPC(it, nullability = nullability) } /** * Read [Arrow interprocess streaming format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-streaming-format) data from existing [url] */ -public fun DataFrame.Companion.readArrowIPC(url: URL): AnyFrame = +public fun DataFrame.Companion.readArrowIPC(url: URL, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = when { - isFile(url) -> readArrowIPC(urlAsFile(url)) - isProtocolSupported(url) -> url.openStream().use { readArrowIPC(it) } + isFile(url) -> readArrowIPC(urlAsFile(url), nullability) + isProtocolSupported(url) -> url.openStream().use { readArrowIPC(it, nullability) } else -> { throw IllegalArgumentException("Invalid protocol for url $url") } } -public fun DataFrame.Companion.readArrowIPC(path: String): AnyFrame = if (isURL(path)) { - readArrowIPC(URL(path)) +public fun DataFrame.Companion.readArrowIPC(path: String, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = if (isURL(path)) { + readArrowIPC(URL(path), nullability) } else { - readArrowIPC(File(path)) + readArrowIPC(File(path), nullability) } // Feather reading block @@ -242,25 +312,28 @@ public fun DataFrame.Companion.readArrowIPC(path: String): AnyFrame = if (isURL( /** * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [file] */ -public fun DataFrame.Companion.readArrowFeather(file: File): AnyFrame = Files.newByteChannel(file.toPath()).use { readArrowFeather(it) } +public fun DataFrame.Companion.readArrowFeather(file: File, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = + Files.newByteChannel(file.toPath()).use { readArrowFeather(it, nullability = nullability) } /** * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [byteArray] */ -public fun DataFrame.Companion.readArrowFeather(byteArray: ByteArray): AnyFrame = SeekableInMemoryByteChannel(byteArray).use { readArrowFeather(it) } +public fun DataFrame.Companion.readArrowFeather(byteArray: ByteArray, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = + SeekableInMemoryByteChannel(byteArray).use { readArrowFeather(it, nullability = nullability) } /** * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [stream] */ -public fun DataFrame.Companion.readArrowFeather(stream: InputStream): AnyFrame = readArrowFeather(stream.readBytes()) +public fun DataFrame.Companion.readArrowFeather(stream: InputStream, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = + readArrowFeather(stream.readBytes(), nullability) /** * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [url] */ -public fun DataFrame.Companion.readArrowFeather(url: URL): AnyFrame = +public fun DataFrame.Companion.readArrowFeather(url: URL, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = when { - isFile(url) -> readArrowFeather(urlAsFile(url)) - isProtocolSupported(url) -> readArrowFeather(url.readBytes()) + isFile(url) -> readArrowFeather(urlAsFile(url), nullability) + isProtocolSupported(url) -> readArrowFeather(url.readBytes(), nullability) else -> { throw IllegalArgumentException("Invalid protocol for url $url") } @@ -269,8 +342,8 @@ public fun DataFrame.Companion.readArrowFeather(url: URL): AnyFrame = /** * Read [Arrow random access format](https://arrow.apache.org/docs/java/ipc.html#writing-and-reading-random-access-files) data from existing [path] */ -public fun DataFrame.Companion.readArrowFeather(path: String): AnyFrame = if (isURL(path)) { - readArrowFeather(URL(path)) +public fun DataFrame.Companion.readArrowFeather(path: String, nullability: NullabilityOptions = NullabilityOptions.Infer): AnyFrame = if (isURL(path)) { + readArrowFeather(URL(path), nullability) } else { - readArrowFeather(File(path)) + readArrowFeather(File(path), nullability) } diff --git a/dataframe-arrow/src/test/kotlin/ArrowKtTest.kt b/dataframe-arrow/src/test/kotlin/ArrowKtTest.kt index 97d53804f3..813a7425b3 100644 --- a/dataframe-arrow/src/test/kotlin/ArrowKtTest.kt +++ b/dataframe-arrow/src/test/kotlin/ArrowKtTest.kt @@ -1,10 +1,13 @@ +import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe import org.apache.arrow.vector.util.Text import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.api.NullabilityOptions import org.jetbrains.kotlinx.dataframe.api.columnOf import org.jetbrains.kotlinx.dataframe.api.dataFrameOf import org.jetbrains.kotlinx.dataframe.api.toColumn import org.jetbrains.kotlinx.dataframe.io.readArrowFeather +import org.jetbrains.kotlinx.dataframe.io.readArrowIPC import org.junit.Test import java.net.URL @@ -13,6 +16,7 @@ internal class ArrowKtTest { fun testResource(resourcePath: String): URL = ArrowKtTest::class.java.classLoader.getResource(resourcePath)!! fun testArrowFeather(name: String) = testResource("$name.feather") + fun testArrowIPC(name: String) = testResource("$name.ipc") @Test fun testReadingFromFile() { @@ -31,4 +35,56 @@ internal class ArrowKtTest { val expected = dataFrameOf(a, b, c, d) df shouldBe expected } + + @Test + fun testReadingAllTypesAsEstimated() { + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test.arrow"), NullabilityOptions.Infer), false, false) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test.arrow"), NullabilityOptions.Infer), false, false) + + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test.arrow"), NullabilityOptions.Checking), true, false) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test.arrow"), NullabilityOptions.Checking), true, false) + + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test.arrow"), NullabilityOptions.Widening), true, false) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test.arrow"), NullabilityOptions.Widening), true, false) + } + + @Test + fun testReadingAllTypesAsEstimatedWithNulls() { + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-with-nulls.arrow"), NullabilityOptions.Infer), true, true) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-with-nulls.arrow"), NullabilityOptions.Infer), true, true) + + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-with-nulls.arrow"), NullabilityOptions.Checking), true, true) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-with-nulls.arrow"), NullabilityOptions.Checking), true, true) + + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-with-nulls.arrow"), NullabilityOptions.Widening), true, true) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-with-nulls.arrow"), NullabilityOptions.Widening), true, true) + } + + @Test + fun testReadingAllTypesAsEstimatedNotNullable() { + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-not-nullable.arrow"), NullabilityOptions.Infer), false, false) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-not-nullable.arrow"), NullabilityOptions.Infer), false, false) + + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-not-nullable.arrow"), NullabilityOptions.Checking), false, false) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-not-nullable.arrow"), NullabilityOptions.Checking), false, false) + + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-not-nullable.arrow"), NullabilityOptions.Widening), false, false) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-not-nullable.arrow"), NullabilityOptions.Widening), false, false) + } + + @Test + fun testReadingAllTypesAsEstimatedNotNullableWithNulls() { + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-illegal.arrow"), NullabilityOptions.Infer), true, true) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-illegal.arrow"), NullabilityOptions.Infer), true, true) + + shouldThrow { + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-illegal.arrow"), NullabilityOptions.Checking), false, true) + } + shouldThrow { + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-illegal.arrow"), NullabilityOptions.Checking), false, true) + } + + assertEstimations(DataFrame.readArrowFeather(testArrowFeather("test-illegal.arrow"), NullabilityOptions.Widening), true, true) + assertEstimations(DataFrame.readArrowIPC(testArrowIPC("test-illegal.arrow"), NullabilityOptions.Widening), true, true) + } } diff --git a/dataframe-arrow/src/test/kotlin/exampleEstimatesAssertions.kt b/dataframe-arrow/src/test/kotlin/exampleEstimatesAssertions.kt new file mode 100644 index 0000000000..161753cede --- /dev/null +++ b/dataframe-arrow/src/test/kotlin/exampleEstimatesAssertions.kt @@ -0,0 +1,159 @@ +import io.kotest.matchers.shouldBe +import org.jetbrains.kotlinx.dataframe.AnyFrame +import org.jetbrains.kotlinx.dataframe.DataColumn +import org.jetbrains.kotlinx.dataframe.api.forEachIndexed +import java.math.BigInteger +import java.time.LocalDate +import java.time.LocalDateTime +import java.time.LocalTime +import java.time.ZoneOffset +import kotlin.math.absoluteValue +import kotlin.math.pow +import kotlin.reflect.full.withNullability +import kotlin.reflect.typeOf + +/** + * Assert that we have got the same data that was originally saved on example creation. + */ +internal fun assertEstimations(exampleFrame: AnyFrame, expectedNullable: Boolean, hasNulls: Boolean) { + /** + * In [exampleFrame] we get two concatenated batches. To assert the estimations, we should transform frame row number to batch row number + */ + fun iBatch(iFrame: Int): Int { + val firstBatchSize = 100; + return if (iFrame < firstBatchSize) iFrame else iFrame - firstBatchSize + } + + fun expectedNull(rowNumber: Int): Boolean { + return (rowNumber + 1) % 5 == 0; + } + + fun assertValueOrNull(rowNumber: Int, actual: Any?, expected: Any) { + if (hasNulls && expectedNull(rowNumber)) { + actual shouldBe null + } else { + actual shouldBe expected + } + } + + val asciiStringCol = exampleFrame["asciiString"] as DataColumn + asciiStringCol.type() shouldBe typeOf().withNullability(expectedNullable) + asciiStringCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, "Test Example ${iBatch(i)}") + } + + val utf8StringCol = exampleFrame["utf8String"] as DataColumn + utf8StringCol.type() shouldBe typeOf().withNullability(expectedNullable) + utf8StringCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, "Тестовый пример ${iBatch(i)}") + } + + val largeStringCol = exampleFrame["largeString"] as DataColumn + largeStringCol.type() shouldBe typeOf().withNullability(expectedNullable) + largeStringCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, "Test Example Should Be Large ${iBatch(i)}") + } + + val booleanCol = exampleFrame["boolean"] as DataColumn + booleanCol.type() shouldBe typeOf().withNullability(expectedNullable) + booleanCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, iBatch(i) % 2 == 0) + } + + val byteCol = exampleFrame["byte"] as DataColumn + byteCol.type() shouldBe typeOf().withNullability(expectedNullable) + byteCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, (iBatch(i) * 10).toByte()) + } + + val shortCol = exampleFrame["short"] as DataColumn + shortCol.type() shouldBe typeOf().withNullability(expectedNullable) + shortCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, (iBatch(i) * 1000).toShort()) + } + + val intCol = exampleFrame["int"] as DataColumn + intCol.type() shouldBe typeOf().withNullability(expectedNullable) + intCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, iBatch(i) * 100000000) + } + + val longCol = exampleFrame["longInt"] as DataColumn + longCol.type() shouldBe typeOf().withNullability(expectedNullable) + longCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, iBatch(i) * 100000000000000000L) + } + + val unsignedByteCol = exampleFrame["unsigned_byte"] as DataColumn + unsignedByteCol.type() shouldBe typeOf().withNullability(expectedNullable) + unsignedByteCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, (iBatch(i) * 10 % (Byte.MIN_VALUE.toShort() * 2).absoluteValue).toShort()) + } + + val unsignedShortCol = exampleFrame["unsigned_short"] as DataColumn + unsignedShortCol.type() shouldBe typeOf().withNullability(expectedNullable) + unsignedShortCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, iBatch(i) * 1000 % (Short.MIN_VALUE.toInt() * 2).absoluteValue) + } + + val unsignedIntCol = exampleFrame["unsigned_int"] as DataColumn + unsignedIntCol.type() shouldBe typeOf().withNullability(expectedNullable) + unsignedIntCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, iBatch(i).toLong() * 100000000 % (Int.MIN_VALUE.toLong() * 2).absoluteValue) + } + + val unsignedLongIntCol = exampleFrame["unsigned_longInt"] as DataColumn + unsignedLongIntCol.type() shouldBe typeOf().withNullability(expectedNullable) + unsignedLongIntCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, iBatch(i).toBigInteger() * 100000000000000000L.toBigInteger() % (Long.MIN_VALUE.toBigInteger() * 2.toBigInteger()).abs()) + } + + val floatCol = exampleFrame["float"] as DataColumn + floatCol.type() shouldBe typeOf().withNullability(expectedNullable) + floatCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, 2.0f.pow(iBatch(i).toFloat())) + } + + val doubleCol = exampleFrame["double"] as DataColumn + doubleCol.type() shouldBe typeOf().withNullability(expectedNullable) + doubleCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, 2.0.pow(iBatch(i))) + } + + val dateCol = exampleFrame["date32"] as DataColumn + dateCol.type() shouldBe typeOf().withNullability(expectedNullable) + dateCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, LocalDate.ofEpochDay(iBatch(i).toLong() * 30)) + } + + val datetimeCol = exampleFrame["date64"] as DataColumn + datetimeCol.type() shouldBe typeOf().withNullability(expectedNullable) + datetimeCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, LocalDateTime.ofEpochSecond(iBatch(i).toLong() * 60 * 60 * 24 * 30, 0, ZoneOffset.UTC)) + } + + val timeSecCol = exampleFrame["time32_seconds"] as DataColumn + timeSecCol.type() shouldBe typeOf().withNullability(expectedNullable) + timeSecCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, LocalTime.ofSecondOfDay(iBatch(i).toLong())) + } + + val timeMilliCol = exampleFrame["time32_milli"] as DataColumn + timeMilliCol.type() shouldBe typeOf().withNullability(expectedNullable) + timeMilliCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, LocalTime.ofNanoOfDay(iBatch(i).toLong() * 1000_000)) + } + + val timeMicroCol = exampleFrame["time64_micro"] as DataColumn + timeMicroCol.type() shouldBe typeOf().withNullability(expectedNullable) + timeMicroCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, LocalTime.ofNanoOfDay(iBatch(i).toLong() * 1000)) + } + + val timeNanoCol = exampleFrame["time64_nano"] as DataColumn + timeNanoCol.type() shouldBe typeOf().withNullability(expectedNullable) + timeNanoCol.forEachIndexed { i, element -> + assertValueOrNull(iBatch(i), element, LocalTime.ofNanoOfDay(iBatch(i).toLong())) + } + +} diff --git a/dataframe-arrow/src/test/resources/test-illegal.arrow.feather b/dataframe-arrow/src/test/resources/test-illegal.arrow.feather new file mode 100644 index 0000000000..eddaf458ca Binary files /dev/null and b/dataframe-arrow/src/test/resources/test-illegal.arrow.feather differ diff --git a/dataframe-arrow/src/test/resources/test-illegal.arrow.ipc b/dataframe-arrow/src/test/resources/test-illegal.arrow.ipc new file mode 100644 index 0000000000..de09051e67 Binary files /dev/null and b/dataframe-arrow/src/test/resources/test-illegal.arrow.ipc differ diff --git a/dataframe-arrow/src/test/resources/test-not-nullable.arrow.feather b/dataframe-arrow/src/test/resources/test-not-nullable.arrow.feather new file mode 100644 index 0000000000..c807ad9d84 Binary files /dev/null and b/dataframe-arrow/src/test/resources/test-not-nullable.arrow.feather differ diff --git a/dataframe-arrow/src/test/resources/test-not-nullable.arrow.ipc b/dataframe-arrow/src/test/resources/test-not-nullable.arrow.ipc new file mode 100644 index 0000000000..228cdbfc5d Binary files /dev/null and b/dataframe-arrow/src/test/resources/test-not-nullable.arrow.ipc differ diff --git a/dataframe-arrow/src/test/resources/test-with-nulls.arrow.feather b/dataframe-arrow/src/test/resources/test-with-nulls.arrow.feather new file mode 100644 index 0000000000..129128f9f6 Binary files /dev/null and b/dataframe-arrow/src/test/resources/test-with-nulls.arrow.feather differ diff --git a/dataframe-arrow/src/test/resources/test-with-nulls.arrow.ipc b/dataframe-arrow/src/test/resources/test-with-nulls.arrow.ipc new file mode 100644 index 0000000000..0db25e66fc Binary files /dev/null and b/dataframe-arrow/src/test/resources/test-with-nulls.arrow.ipc differ diff --git a/dataframe-arrow/src/test/resources/test.arrow.feather b/dataframe-arrow/src/test/resources/test.arrow.feather new file mode 100644 index 0000000000..4a348d1e2d Binary files /dev/null and b/dataframe-arrow/src/test/resources/test.arrow.feather differ diff --git a/dataframe-arrow/src/test/resources/test.arrow.ipc b/dataframe-arrow/src/test/resources/test.arrow.ipc new file mode 100644 index 0000000000..61e8c31afa Binary files /dev/null and b/dataframe-arrow/src/test/resources/test.arrow.ipc differ