diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index 548fc47048..af9bea3657 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -25,6 +25,7 @@ import org.jetbrains.kotlinx.dataframe.impl.zero import org.jetbrains.kotlinx.dataframe.math.sum import org.jetbrains.kotlinx.dataframe.math.sumOf import kotlin.reflect.KProperty +import kotlin.reflect.full.isSubtypeOf import kotlin.reflect.typeOf // region DataColumn @@ -42,7 +43,11 @@ public inline fun DataColumn.sumOf(crossinline expres // region DataRow -public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateMixed(values().filterIsInstance()) ?: 0 +public fun AnyRow.rowSum(): Number = + Aggregators.sum.aggregateMixed( + values = values().filterIsInstance(), + types = columnTypes().filter { it.isSubtypeOf(typeOf()) }.toSet(), + ) ?: 0 public inline fun AnyRow.rowSumOf(): T = values().filterIsInstance().sum(typeOf()) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/UnifyingNumbers.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/UnifyingNumbers.kt new file mode 100644 index 0000000000..42db06463c --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/documentation/UnifyingNumbers.kt @@ -0,0 +1,37 @@ +package org.jetbrains.kotlinx.dataframe.documentation + +/** + * ## Unifying Numbers + * + * The concept of unifying numbers is converting them to a common number type without losing information. + * + * The following graph shows the hierarchy of number types in Kotlin DataFrame. + * The order is top-down from the most complex type to the simplest one. + * + * {@include [Graph]} + * For each number type in the graph, it holds that a number of that type can be expressed lossless by + * a number of a more complex type (any of its parents). + * This is either because the more complex type has a larger range or higher precision (in terms of bits). + */ +internal interface UnifyingNumbers { + + /** + * ``` + * BigDecimal + * / \\ + * BigInteger \\ + * / \\ \\ + * ULong Long Double + * .. | / | / | \\.. + * \\ | / | / | + * UInt Int Float + * .. | / | / \\.. + * \\ | / | / + * UShort Short + * | / | + * | / | + * UByte Byte + * ``` + */ + interface Graph +} diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/DirectedAcyclicGraph.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/DirectedAcyclicGraph.kt new file mode 100644 index 0000000000..486c64bcaa --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/DirectedAcyclicGraph.kt @@ -0,0 +1,176 @@ +package org.jetbrains.kotlinx.dataframe.impl + +import kotlin.experimental.ExperimentalTypeInference + +/** + * Represents a directed acyclic graph (DAG) of generic type [T]. + * + * This class is immutable and guarantees that the graph does not contain any cycles. + * It provides functionality to find the nearest common ancestor of two vertices + * in the graph ([findNearestCommonVertex]). + * + * Use the [Builder] class or [buildDag] function to create a new instance of this class. + * + * @param T The type of items in the graph. + * @property adjacencyList A map representing directed edges, where the keys are source vertices + * and the values are sets of destination vertices. + * @property vertices A set of all vertices in the graph. + */ +internal class DirectedAcyclicGraph private constructor( + private val adjacencyList: Map>, + private val vertices: Set, +) { + class Builder { + private val edges = mutableListOf>() + private val vertices = mutableSetOf() + + fun addEdge(from: T, to: T): Builder { + edges.add(from to to) + vertices.add(from) + vertices.add(to) + return this + } + + fun addEdges(vararg edges: Pair): Builder { + edges.forEach { (from, to) -> addEdge(from, to) } + return this + } + + fun build(): DirectedAcyclicGraph { + val adjacencyList = edges.groupBy({ it.first }, { it.second }) + .mapValues { it.value.toSet() } + + if (hasCycle(adjacencyList)) { + throw IllegalStateException("Graph contains cycle") + } + + return DirectedAcyclicGraph(adjacencyList, vertices) + } + + private fun hasCycle(adjacencyList: Map>): Boolean { + val visited = mutableSetOf() + val recursionStack = mutableSetOf() + + fun dfs(vertex: T): Boolean { + if (vertex in recursionStack) return true + if (vertex in visited) return false + + visited.add(vertex) + recursionStack.add(vertex) + + adjacencyList[vertex]?.forEach { neighbor -> + if (dfs(neighbor)) return true + } + + recursionStack.remove(vertex) + return false + } + + return adjacencyList.keys.any { vertex -> + if (vertex !in visited && dfs(vertex)) return true + false + } + } + } + + fun findNearestCommonVertex(vertex1: T, vertex2: T): T? { + if (vertex1 !in vertices || vertex2 !in vertices) return null + if (vertex1 == vertex2) return vertex1 + + // Get all ancestors for both vertices + val ancestors1 = getAllAncestors(vertex1) + val ancestors2 = getAllAncestors(vertex2) + + // If one vertex is an ancestor of another, return that vertex + if (vertex1 in ancestors2) return vertex1 + if (vertex2 in ancestors1) return vertex2 + + // Find common ancestors + val commonAncestors = ancestors1.intersect(ancestors2) + if (commonAncestors.isEmpty()) return null + + // Find the nearest common ancestor by checking distance from both vertices + return commonAncestors.minByOrNull { ancestor -> + getDistance(ancestor, vertex1) + getDistance(ancestor, vertex2) + } + } + + private fun getAllAncestors(vertex: T): Set { + val ancestors = mutableSetOf() + val visited = mutableSetOf() + + fun dfs(current: T) { + if (current in visited) return + visited.add(current) + + adjacencyList.forEach { (parent, children) -> + if (current in children) { + ancestors.add(parent) + dfs(parent) + } + } + } + + dfs(vertex) + return ancestors + } + + private fun getDistance(from: T, to: T): Int { + if (from == to) return 0 + + val distances = mutableMapOf() + val queue = ArrayDeque() + + queue.add(from) + distances[from] = 0 + + while (queue.isNotEmpty()) { + val current = queue.removeFirst() + val currentDistance = distances[current] ?: continue + + adjacencyList[current]?.forEach { neighbor -> + if (neighbor !in distances) { + distances[neighbor] = currentDistance + 1 + queue.add(neighbor) + if (neighbor == to) return currentDistance + 1 + } + } + } + + return Int.MAX_VALUE + } + + fun map(conversion: (T) -> R): DirectedAcyclicGraph { + val cache = mutableMapOf() + val cachedConversion: (T) -> R = { cache.getOrPut(it) { conversion(it) } } + + return Builder().apply { + for ((from, to) in adjacencyList) { + for (to in to) { + addEdge(from = cachedConversion(from), to = cachedConversion(to)) + } + } + }.build() + } + + companion object { + fun builder(): Builder = Builder() + } +} + +/** + * Builds a new [DirectedAcyclicGraph] using the provided [builder] function. + * + * @see DirectedAcyclicGraph + */ +@OptIn(ExperimentalTypeInference::class) +internal fun buildDag( + @BuilderInference builder: DirectedAcyclicGraph.Builder.() -> Unit, +): DirectedAcyclicGraph = DirectedAcyclicGraph.builder().apply(builder).build() + +/** + * Builds a new [DirectedAcyclicGraph] using the provided [edges]. + * + * @see DirectedAcyclicGraph + */ +internal fun dagOf(vararg edges: Pair): DirectedAcyclicGraph = buildDag { addEdges(*edges) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt new file mode 100644 index 0000000000..06f2a92a74 --- /dev/null +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt @@ -0,0 +1,139 @@ +package org.jetbrains.kotlinx.dataframe.impl + +import org.jetbrains.kotlinx.dataframe.documentation.UnifyingNumbers +import org.jetbrains.kotlinx.dataframe.impl.api.createConverter +import java.math.BigDecimal +import java.math.BigInteger +import kotlin.reflect.KClass +import kotlin.reflect.KType +import kotlin.reflect.full.withNullability +import kotlin.reflect.typeOf + +/** + * Number type graph, structured in terms of number complexity. + * A number can always be expressed lossless by a number of a more complex type (any of its parents). + * + * {@include [UnifyingNumbers.Graph]} + * + * For any two numbers, we can find the nearest common ancestor in this graph + * by calling [DirectedAcyclicGraph.findNearestCommonVertex]. + * @see getUnifiedNumberClass + * @see unifiedNumberClass + * @see UnifyingNumbers + */ +internal val unifiedNumberTypeGraph: DirectedAcyclicGraph by lazy { + buildDag { + addEdge(typeOf(), typeOf()) + addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), typeOf()) + addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), typeOf()) + addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), typeOf()) + addEdge(typeOf(), typeOf()) + addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), typeOf()) + addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), typeOf()) + addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), typeOf()) + + addEdge(typeOf(), typeOf()) + addEdge(typeOf(), typeOf()) + } +} + +/** @include [unifiedNumberTypeGraph] */ +internal val unifiedNumberClassGraph: DirectedAcyclicGraph> by lazy { + unifiedNumberTypeGraph.map { it.classifier as KClass<*> } +} + +/** + * Determines the nearest common numeric type, in terms of complexity, between two given classes/types. + * + * Unsigned types are supported too even though they are not a [Number] instance, + * but unless two unsigned types are provided in the input, it will never be returned. + * Meaning, a single [Number] input, the output will always be a [Number]. + * + * @param first The first numeric type to compare. Can be null, in which case the second to is returned. + * @param second The second numeric to compare. Cannot be null. + * @return The nearest common numeric type between the two input classes. + * If no common class is found, [IllegalStateException] is thrown. + * @see UnifyingNumbers + */ +internal fun getUnifiedNumberType(first: KType?, second: KType): KType { + if (first == null) return second + + val firstWithoutNullability = first.withNullability(false) + val secondWithoutNullability = second.withNullability(false) + + val result = if (firstWithoutNullability == secondWithoutNullability) { + firstWithoutNullability + } else { + unifiedNumberTypeGraph.findNearestCommonVertex(firstWithoutNullability, secondWithoutNullability) + ?: error("Can not find common number type for $first and $second") + } + + return if (first.isMarkedNullable || second.isMarkedNullable) result.withNullability(true) else result +} + +/** @include [getUnifiedNumberType] */ +@Suppress("IntroduceWhenSubject") +internal fun getUnifiedNumberClass(first: KClass<*>?, second: KClass<*>): KClass<*> = + when { + first == null -> second + + first == second -> first + + else -> unifiedNumberClassGraph.findNearestCommonVertex(first, second) + ?: error("Can not find common number type for $first and $second") + } + +/** + * Determines the nearest common numeric type, in terms of complexity, all types in [this]. + * + * Unsigned types are supported too even though they are not a [Number] instance, + * but unless the input solely exists of unsigned numbers, it will never be returned. + * Meaning, given a [Number] in the input, the output will always be a [Number]. + * + * @return The nearest common numeric type between the input types. + * If no common type is found, it returns [Number]. + * @see UnifyingNumbers + */ +internal fun Iterable.unifiedNumberType(): KType = + fold(null as KType?, ::getUnifiedNumberType) ?: typeOf() + +/** @include [unifiedNumberType] */ +internal fun Iterable>.unifiedNumberClass(): KClass<*> = + fold(null as KClass<*>?, ::getUnifiedNumberClass) ?: Number::class + +/** + * Converts the elements of the given iterable of numbers into a common numeric type based on complexity. + * The common numeric type is determined using the provided [commonNumberType] parameter + * or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified. + * + * @param commonNumberType The desired common numeric type to convert the elements to. + * This is determined by default using the types of the elements in the iterable. + * @return A new iterable of numbers where each element is converted to the specified or inferred common number type. + * @throws IllegalStateException if an element cannot be converted to the common number type. + * @see UnifyingNumbers + */ +@Suppress("UNCHECKED_CAST") +internal fun Iterable.convertToUnifiedNumberType( + commonNumberType: KType = this.types().unifiedNumberType(), +): Iterable { + val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? + return map { + converter(it) ?: error("Can not convert $it to $commonNumberType") + } +} diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt index 486fce32a8..38be1760bf 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt @@ -167,35 +167,6 @@ internal fun resolve(actualType: KType, declaredType: KType): Map, KClass<*>>, KClass<*>> by lazy { - val map = mutableMapOf, KClass<*>>, KClass<*>>() - - fun add(from: KClass<*>, to: KClass<*>) { - map[from to to] = to - map[to to from] = to - } - - val intTypes = listOf(Byte::class, Short::class, Int::class, Long::class) - for (i in intTypes.indices) { - for (j in i + 1 until intTypes.size) { - add(intTypes[i], intTypes[j]) - } - add(intTypes[i], Double::class) - } - add(Float::class, Double::class) - map -} - -internal fun getCommonNumberType(first: KClass<*>?, second: KClass<*>): KClass<*> = - when { - first == null -> second - first == second -> first - else -> numberTypeExtensions[first to second] ?: error("Can not find common number type for $first and $second") - } - -internal fun Iterable>.commonNumberClass(): KClass<*> = - fold(null as KClass<*>?, ::getCommonNumberType) ?: Number::class - internal fun commonParent(classes: Iterable>): KClass<*>? = commonParents(classes).withMostSuperclasses() internal fun commonParent(vararg classes: KClass<*>): KClass<*>? = commonParent(classes.toList()) @@ -648,3 +619,27 @@ internal fun Any.asArrayAsListOrNull(): List<*>? = } internal fun Any.isBigNumber(): Boolean = this is BigInteger || this is BigDecimal + +/** + * Returns a set containing the [KClass] of each element in the iterable. + * + * This can be a heavy operation! + * + * The [KClass] is determined by retrieving the runtime class of each element. + * + * @return A set of [KClass] objects representing the runtime types of elements in the iterable. + */ +internal fun Iterable.classes(): Set> = mapTo(mutableSetOf()) { it::class } + +/** + * Returns a set of [KType] objects representing the star-projected types of the runtime classes + * of all unique elements in the iterable. + * + * The method internally relies on the [classes] function to collect the runtime classes of the + * elements in the iterable and then maps each class to its star-projected type. + * + * This can be a heavy operation! + * + * @return A set of [KType] objects corresponding to the star-projected runtime types of elements in the iterable. + */ +internal fun Iterable.types(): Set = classes().mapTo(mutableSetOf()) { it.createStarProjectedType(false) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index 81ea4424c9..11da3971ac 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -1,6 +1,5 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators -import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators.std import org.jetbrains.kotlinx.dataframe.math.mean import org.jetbrains.kotlinx.dataframe.math.median import org.jetbrains.kotlinx.dataframe.math.std diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/NumbersAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/NumbersAggregator.kt index db99322cc5..00ef22febe 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/NumbersAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/NumbersAggregator.kt @@ -1,26 +1,36 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators import org.jetbrains.kotlinx.dataframe.DataColumn -import org.jetbrains.kotlinx.dataframe.impl.commonNumberClass -import org.jetbrains.kotlinx.dataframe.impl.createStarProjectedType +import org.jetbrains.kotlinx.dataframe.impl.convertToUnifiedNumberType +import org.jetbrains.kotlinx.dataframe.impl.unifiedNumberType import kotlin.reflect.KProperty import kotlin.reflect.KType -internal class NumbersAggregator(name: String, aggregate: (Iterable, KType) -> C?) : - AggregatorBase(name, aggregate) { +internal class NumbersAggregator(name: String, aggregate: (Iterable, KType) -> Number?) : + AggregatorBase(name, aggregate) { - override fun aggregate(columns: Iterable>): C? = aggregateMixed(columns.mapNotNull { aggregate(it) }) + override fun aggregate(columns: Iterable>): Number? = + aggregateMixed( + values = columns.mapNotNull { aggregate(it) }, + types = columns.map { it.type() }.toSet(), + ) class Factory(private val aggregate: Iterable.(KType) -> Number?) : AggregatorProvider { override fun create(name: String) = NumbersAggregator(name, aggregate) - override operator fun getValue(obj: Any?, property: KProperty<*>): NumbersAggregator = - create(property.name) + override operator fun getValue(obj: Any?, property: KProperty<*>): NumbersAggregator = create(property.name) } - fun aggregateMixed(values: Iterable): C? { - val classes = values.map { it.javaClass.kotlin } - return aggregate(values, classes.commonNumberClass().createStarProjectedType(false)) + /** + * Can aggregate numbers with different types by first converting them to a compatible type. + */ + @Suppress("UNCHECKED_CAST") + fun aggregateMixed(values: Iterable, types: Set): Number? { + val commonType = types.unifiedNumberType() + return aggregate( + values = values.convertToUnifiedNumberType(commonType), + type = commonType, + ) } override val preservesType = false diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/impl/DirectedAcyclicGraphTest.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/impl/DirectedAcyclicGraphTest.kt new file mode 100644 index 0000000000..aeabf98d89 --- /dev/null +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/impl/DirectedAcyclicGraphTest.kt @@ -0,0 +1,89 @@ +package org.jetbrains.kotlinx.dataframe.impl + +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.nulls.shouldBeNull +import io.kotest.matchers.shouldBe +import org.junit.Test + +class DirectedAcyclicGraphTest { + + @Test + fun `basic graph building`() { + val graph = buildDag { + addEdge("A", "B") + addEdge("B", "C") + } + + graph.findNearestCommonVertex("B", "C") shouldBe "B" + graph.findNearestCommonVertex("A", "C") shouldBe "A" + } + + @Test + fun `cycle detection`() { + shouldThrow { + buildDag { + addEdge("A", "B") + addEdge("B", "C") + addEdge("C", "A") + } + } + } + + @Test + fun `nearest common vertex - same vertex`() { + val graph = buildDag { + addEdge("A", "B") + addEdge("B", "C") + } + + graph.findNearestCommonVertex("B", "B") shouldBe "B" + } + + @Test + fun `nearest common vertex - one is ancestor`() { + val graph = buildDag { + addEdge("A", "B") + addEdge("B", "C") + addEdge("B", "D") + } + + graph.findNearestCommonVertex("B", "D") shouldBe "B" + graph.findNearestCommonVertex("D", "B") shouldBe "B" + } + + @Test + fun `nearest common vertex - common ancestor exists`() { + val graph = buildDag { + addEdge("A", "B") + addEdge("A", "C") + addEdge("B", "D") + addEdge("C", "E") + } + + graph.findNearestCommonVertex("D", "E") shouldBe "A" + } + + @Test + fun `nearest common vertex - no common ancestor`() { + val graph = buildDag { + addEdge("A", "B") + addEdge("C", "D") + } + + graph.findNearestCommonVertex("B", "D").shouldBeNull() + } + + @Test + fun `nearest common vertex - complex case`() { + val graph = buildDag { + addEdge("A", "B") + addEdge("B", "C") + addEdge("A", "D") + addEdge("D", "E") + addEdge("B", "E") + } + + // B is closer to E than A + graph.findNearestCommonVertex("C", "E") shouldBe "B" + } +} diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt index b93d8ab705..513d7f4d19 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt @@ -4,9 +4,11 @@ import io.kotest.matchers.shouldBe import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.api.columnOf import org.jetbrains.kotlinx.dataframe.api.dataFrameOf +import org.jetbrains.kotlinx.dataframe.api.rowSum import org.jetbrains.kotlinx.dataframe.api.sum import org.jetbrains.kotlinx.dataframe.api.sumOf import org.junit.Test +import java.math.BigDecimal class SumTests { @@ -61,4 +63,24 @@ class SumTests { df.sum { value2 } shouldBe expected2 df.sum { value3 } shouldBe expected3 } + + /** [Issue #1068](https://github.com/Kotlin/dataframe/issues/1068) */ + @Test + fun `rowSum mixed number types`() { + dataFrameOf("a", "b")(1, 2f)[0].rowSum().let { + it shouldBe 3.0 + it::class shouldBe Double::class + } + + // NOTE! unsigned numbers are not Number, they are skipped for now + dataFrameOf("a", "b")(1, 2u)[0].rowSum().let { + it shouldBe 1 + it::class shouldBe Int::class + } + + dataFrameOf("a", "b")(1.0, 2L)[0].rowSum().let { + it shouldBe (3.0.toBigDecimal()) + it::class shouldBe BigDecimal::class + } + } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt index 57bcd7b78b..3518039dbf 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt @@ -4,12 +4,14 @@ import io.kotest.matchers.shouldBe import org.jetbrains.kotlinx.dataframe.DataColumn import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.api.columnOf +import org.jetbrains.kotlinx.dataframe.documentation.UnifyingNumbers import org.jetbrains.kotlinx.dataframe.impl.asArrayAsListOrNull import org.jetbrains.kotlinx.dataframe.impl.commonParent import org.jetbrains.kotlinx.dataframe.impl.commonParents import org.jetbrains.kotlinx.dataframe.impl.commonType import org.jetbrains.kotlinx.dataframe.impl.commonTypeListifyValues import org.jetbrains.kotlinx.dataframe.impl.createType +import org.jetbrains.kotlinx.dataframe.impl.getUnifiedNumberClass import org.jetbrains.kotlinx.dataframe.impl.guessValueType import org.jetbrains.kotlinx.dataframe.impl.isArray import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveArray @@ -17,6 +19,8 @@ import org.jetbrains.kotlinx.dataframe.impl.nothingType import org.jetbrains.kotlinx.dataframe.impl.replaceGenericTypeParametersWithUpperbound import org.junit.Test import java.io.Serializable +import java.math.BigDecimal +import java.math.BigInteger import kotlin.reflect.KClass import kotlin.reflect.KType import kotlin.reflect.typeOf @@ -414,4 +418,48 @@ class UtilTests { typeOf?>(), ).commonTypeListifyValues() shouldBe typeOf?>() } + + /** + * See [UnifyingNumbers] for more information. + * {@include [UnifyingNumbers.Graph]} + */ + @Test + fun `common number types`() { + // Same type + getUnifiedNumberClass(Int::class, Int::class) shouldBe Int::class + getUnifiedNumberClass(Double::class, Double::class) shouldBe Double::class + + // Direct parent-child relationships + getUnifiedNumberClass(Int::class, UShort::class) shouldBe Int::class + getUnifiedNumberClass(Long::class, UInt::class) shouldBe Long::class + getUnifiedNumberClass(Double::class, Float::class) shouldBe Double::class + getUnifiedNumberClass(UShort::class, Short::class) shouldBe Int::class + getUnifiedNumberClass(UByte::class, Byte::class) shouldBe Short::class + + getUnifiedNumberClass(UByte::class, UShort::class) shouldBe UShort::class + + // Multi-level relationships + getUnifiedNumberClass(Byte::class, Int::class) shouldBe Int::class + getUnifiedNumberClass(UByte::class, Long::class) shouldBe Long::class + getUnifiedNumberClass(Short::class, Double::class) shouldBe Double::class + getUnifiedNumberClass(UInt::class, Int::class) shouldBe Long::class + + // Top-level types + getUnifiedNumberClass(BigDecimal::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClass(BigInteger::class, Long::class) shouldBe BigInteger::class + getUnifiedNumberClass(BigDecimal::class, BigInteger::class) shouldBe BigDecimal::class + + // Distant relationships + getUnifiedNumberClass(Byte::class, BigDecimal::class) shouldBe BigDecimal::class + getUnifiedNumberClass(UByte::class, Double::class) shouldBe Double::class + + // Complex type promotions + getUnifiedNumberClass(Int::class, Float::class) shouldBe Double::class + getUnifiedNumberClass(Long::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClass(ULong::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClass(BigInteger::class, Double::class) shouldBe BigDecimal::class + + // Edge case with null + getUnifiedNumberClass(null, Int::class) shouldBe Int::class + } }