Skip to content

Commit ed58e48

Browse files
authored
Merge pull request #1070 from Kotlin/number-types
Unified number types
2 parents 65ea1ee + 6458d84 commit ed58e48

File tree

9 files changed

+561
-40
lines changed

9 files changed

+561
-40
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.jetbrains.kotlinx.dataframe.impl.zero
2525
import org.jetbrains.kotlinx.dataframe.math.sum
2626
import org.jetbrains.kotlinx.dataframe.math.sumOf
2727
import kotlin.reflect.KProperty
28+
import kotlin.reflect.full.isSubtypeOf
2829
import kotlin.reflect.typeOf
2930

3031
// region DataColumn
@@ -42,7 +43,11 @@ public inline fun <T, reified R : Number> DataColumn<T>.sumOf(crossinline expres
4243

4344
// region DataRow
4445

45-
public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateMixed(values().filterIsInstance<Number>()) ?: 0
46+
public fun AnyRow.rowSum(): Number =
47+
Aggregators.sum.aggregateMixed(
48+
values = values().filterIsInstance<Number>(),
49+
types = columnTypes().filter { it.isSubtypeOf(typeOf<Number?>()) }.toSet(),
50+
) ?: 0
4651

4752
public inline fun <reified T : Number> AnyRow.rowSumOf(): T = values().filterIsInstance<T>().sum(typeOf<T>())
4853

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package org.jetbrains.kotlinx.dataframe.documentation
2+
3+
/**
4+
* ## Unifying Numbers
5+
*
6+
* The concept of unifying numbers is converting them to a common number type without losing information.
7+
*
8+
* The following graph shows the hierarchy of number types in Kotlin DataFrame.
9+
* The order is top-down from the most complex type to the simplest one.
10+
*
11+
* {@include [Graph]}
12+
* For each number type in the graph, it holds that a number of that type can be expressed lossless by
13+
* a number of a more complex type (any of its parents).
14+
* This is either because the more complex type has a larger range or higher precision (in terms of bits).
15+
*/
16+
internal interface UnifyingNumbers {
17+
18+
/**
19+
* ```
20+
* BigDecimal
21+
* / \\
22+
* BigInteger \\
23+
* / \\ \\
24+
* ULong Long Double
25+
* .. | / | / | \\..
26+
* \\ | / | / |
27+
* UInt Int Float
28+
* .. | / | / \\..
29+
* \\ | / | /
30+
* UShort Short
31+
* | / |
32+
* | / |
33+
* UByte Byte
34+
* ```
35+
*/
36+
interface Graph
37+
}
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
package org.jetbrains.kotlinx.dataframe.impl
2+
3+
import kotlin.experimental.ExperimentalTypeInference
4+
5+
/**
6+
* Represents a directed acyclic graph (DAG) of generic type [T].
7+
*
8+
* This class is immutable and guarantees that the graph does not contain any cycles.
9+
* It provides functionality to find the nearest common ancestor of two vertices
10+
* in the graph ([findNearestCommonVertex]).
11+
*
12+
* Use the [Builder] class or [buildDag] function to create a new instance of this class.
13+
*
14+
* @param T The type of items in the graph.
15+
* @property adjacencyList A map representing directed edges, where the keys are source vertices
16+
* and the values are sets of destination vertices.
17+
* @property vertices A set of all vertices in the graph.
18+
*/
19+
internal class DirectedAcyclicGraph<T> private constructor(
20+
private val adjacencyList: Map<T, Set<T>>,
21+
private val vertices: Set<T>,
22+
) {
23+
class Builder<T> {
24+
private val edges = mutableListOf<Pair<T, T>>()
25+
private val vertices = mutableSetOf<T>()
26+
27+
fun addEdge(from: T, to: T): Builder<T> {
28+
edges.add(from to to)
29+
vertices.add(from)
30+
vertices.add(to)
31+
return this
32+
}
33+
34+
fun addEdges(vararg edges: Pair<T, T>): Builder<T> {
35+
edges.forEach { (from, to) -> addEdge(from, to) }
36+
return this
37+
}
38+
39+
fun build(): DirectedAcyclicGraph<T> {
40+
val adjacencyList = edges.groupBy({ it.first }, { it.second })
41+
.mapValues { it.value.toSet() }
42+
43+
if (hasCycle(adjacencyList)) {
44+
throw IllegalStateException("Graph contains cycle")
45+
}
46+
47+
return DirectedAcyclicGraph(adjacencyList, vertices)
48+
}
49+
50+
private fun hasCycle(adjacencyList: Map<T, Set<T>>): Boolean {
51+
val visited = mutableSetOf<T>()
52+
val recursionStack = mutableSetOf<T>()
53+
54+
fun dfs(vertex: T): Boolean {
55+
if (vertex in recursionStack) return true
56+
if (vertex in visited) return false
57+
58+
visited.add(vertex)
59+
recursionStack.add(vertex)
60+
61+
adjacencyList[vertex]?.forEach { neighbor ->
62+
if (dfs(neighbor)) return true
63+
}
64+
65+
recursionStack.remove(vertex)
66+
return false
67+
}
68+
69+
return adjacencyList.keys.any { vertex ->
70+
if (vertex !in visited && dfs(vertex)) return true
71+
false
72+
}
73+
}
74+
}
75+
76+
fun findNearestCommonVertex(vertex1: T, vertex2: T): T? {
77+
if (vertex1 !in vertices || vertex2 !in vertices) return null
78+
if (vertex1 == vertex2) return vertex1
79+
80+
// Get all ancestors for both vertices
81+
val ancestors1 = getAllAncestors(vertex1)
82+
val ancestors2 = getAllAncestors(vertex2)
83+
84+
// If one vertex is an ancestor of another, return that vertex
85+
if (vertex1 in ancestors2) return vertex1
86+
if (vertex2 in ancestors1) return vertex2
87+
88+
// Find common ancestors
89+
val commonAncestors = ancestors1.intersect(ancestors2)
90+
if (commonAncestors.isEmpty()) return null
91+
92+
// Find the nearest common ancestor by checking distance from both vertices
93+
return commonAncestors.minByOrNull { ancestor ->
94+
getDistance(ancestor, vertex1) + getDistance(ancestor, vertex2)
95+
}
96+
}
97+
98+
private fun getAllAncestors(vertex: T): Set<T> {
99+
val ancestors = mutableSetOf<T>()
100+
val visited = mutableSetOf<T>()
101+
102+
fun dfs(current: T) {
103+
if (current in visited) return
104+
visited.add(current)
105+
106+
adjacencyList.forEach { (parent, children) ->
107+
if (current in children) {
108+
ancestors.add(parent)
109+
dfs(parent)
110+
}
111+
}
112+
}
113+
114+
dfs(vertex)
115+
return ancestors
116+
}
117+
118+
private fun getDistance(from: T, to: T): Int {
119+
if (from == to) return 0
120+
121+
val distances = mutableMapOf<T, Int>()
122+
val queue = ArrayDeque<T>()
123+
124+
queue.add(from)
125+
distances[from] = 0
126+
127+
while (queue.isNotEmpty()) {
128+
val current = queue.removeFirst()
129+
val currentDistance = distances[current] ?: continue
130+
131+
adjacencyList[current]?.forEach { neighbor ->
132+
if (neighbor !in distances) {
133+
distances[neighbor] = currentDistance + 1
134+
queue.add(neighbor)
135+
if (neighbor == to) return currentDistance + 1
136+
}
137+
}
138+
}
139+
140+
return Int.MAX_VALUE
141+
}
142+
143+
fun <R> map(conversion: (T) -> R): DirectedAcyclicGraph<R> {
144+
val cache = mutableMapOf<T, R>()
145+
val cachedConversion: (T) -> R = { cache.getOrPut(it) { conversion(it) } }
146+
147+
return Builder<R>().apply {
148+
for ((from, to) in adjacencyList) {
149+
for (to in to) {
150+
addEdge(from = cachedConversion(from), to = cachedConversion(to))
151+
}
152+
}
153+
}.build()
154+
}
155+
156+
companion object {
157+
fun <T> builder(): Builder<T> = Builder()
158+
}
159+
}
160+
161+
/**
162+
* Builds a new [DirectedAcyclicGraph] using the provided [builder] function.
163+
*
164+
* @see DirectedAcyclicGraph
165+
*/
166+
@OptIn(ExperimentalTypeInference::class)
167+
internal fun <T> buildDag(
168+
@BuilderInference builder: DirectedAcyclicGraph.Builder<T>.() -> Unit,
169+
): DirectedAcyclicGraph<T> = DirectedAcyclicGraph.builder<T>().apply(builder).build()
170+
171+
/**
172+
* Builds a new [DirectedAcyclicGraph] using the provided [edges].
173+
*
174+
* @see DirectedAcyclicGraph
175+
*/
176+
internal fun <T> dagOf(vararg edges: Pair<T, T>): DirectedAcyclicGraph<T> = buildDag { addEdges(*edges) }
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package org.jetbrains.kotlinx.dataframe.impl
2+
3+
import org.jetbrains.kotlinx.dataframe.documentation.UnifyingNumbers
4+
import org.jetbrains.kotlinx.dataframe.impl.api.createConverter
5+
import java.math.BigDecimal
6+
import java.math.BigInteger
7+
import kotlin.reflect.KClass
8+
import kotlin.reflect.KType
9+
import kotlin.reflect.full.withNullability
10+
import kotlin.reflect.typeOf
11+
12+
/**
13+
* Number type graph, structured in terms of number complexity.
14+
* A number can always be expressed lossless by a number of a more complex type (any of its parents).
15+
*
16+
* {@include [UnifyingNumbers.Graph]}
17+
*
18+
* For any two numbers, we can find the nearest common ancestor in this graph
19+
* by calling [DirectedAcyclicGraph.findNearestCommonVertex].
20+
* @see getUnifiedNumberClass
21+
* @see unifiedNumberClass
22+
* @see UnifyingNumbers
23+
*/
24+
internal val unifiedNumberTypeGraph: DirectedAcyclicGraph<KType> by lazy {
25+
buildDag {
26+
addEdge(typeOf<BigDecimal>(), typeOf<BigInteger>())
27+
addEdge(typeOf<BigDecimal>(), typeOf<Double>())
28+
29+
addEdge(typeOf<BigInteger>(), typeOf<ULong>())
30+
addEdge(typeOf<BigInteger>(), typeOf<Long>())
31+
32+
addEdge(typeOf<ULong>(), typeOf<UInt>())
33+
34+
addEdge(typeOf<Long>(), typeOf<UInt>())
35+
addEdge(typeOf<Long>(), typeOf<Int>())
36+
37+
addEdge(typeOf<Double>(), typeOf<Int>())
38+
addEdge(typeOf<Double>(), typeOf<Float>())
39+
addEdge(typeOf<Double>(), typeOf<UInt>())
40+
41+
addEdge(typeOf<UInt>(), typeOf<UShort>())
42+
43+
addEdge(typeOf<Int>(), typeOf<UShort>())
44+
addEdge(typeOf<Int>(), typeOf<Short>())
45+
46+
addEdge(typeOf<Float>(), typeOf<Short>())
47+
addEdge(typeOf<Float>(), typeOf<UShort>())
48+
49+
addEdge(typeOf<UShort>(), typeOf<UByte>())
50+
51+
addEdge(typeOf<Short>(), typeOf<UByte>())
52+
addEdge(typeOf<Short>(), typeOf<Byte>())
53+
}
54+
}
55+
56+
/** @include [unifiedNumberTypeGraph] */
57+
internal val unifiedNumberClassGraph: DirectedAcyclicGraph<KClass<*>> by lazy {
58+
unifiedNumberTypeGraph.map { it.classifier as KClass<*> }
59+
}
60+
61+
/**
62+
* Determines the nearest common numeric type, in terms of complexity, between two given classes/types.
63+
*
64+
* Unsigned types are supported too even though they are not a [Number] instance,
65+
* but unless two unsigned types are provided in the input, it will never be returned.
66+
* Meaning, a single [Number] input, the output will always be a [Number].
67+
*
68+
* @param first The first numeric type to compare. Can be null, in which case the second to is returned.
69+
* @param second The second numeric to compare. Cannot be null.
70+
* @return The nearest common numeric type between the two input classes.
71+
* If no common class is found, [IllegalStateException] is thrown.
72+
* @see UnifyingNumbers
73+
*/
74+
internal fun getUnifiedNumberType(first: KType?, second: KType): KType {
75+
if (first == null) return second
76+
77+
val firstWithoutNullability = first.withNullability(false)
78+
val secondWithoutNullability = second.withNullability(false)
79+
80+
val result = if (firstWithoutNullability == secondWithoutNullability) {
81+
firstWithoutNullability
82+
} else {
83+
unifiedNumberTypeGraph.findNearestCommonVertex(firstWithoutNullability, secondWithoutNullability)
84+
?: error("Can not find common number type for $first and $second")
85+
}
86+
87+
return if (first.isMarkedNullable || second.isMarkedNullable) result.withNullability(true) else result
88+
}
89+
90+
/** @include [getUnifiedNumberType] */
91+
@Suppress("IntroduceWhenSubject")
92+
internal fun getUnifiedNumberClass(first: KClass<*>?, second: KClass<*>): KClass<*> =
93+
when {
94+
first == null -> second
95+
96+
first == second -> first
97+
98+
else -> unifiedNumberClassGraph.findNearestCommonVertex(first, second)
99+
?: error("Can not find common number type for $first and $second")
100+
}
101+
102+
/**
103+
* Determines the nearest common numeric type, in terms of complexity, all types in [this].
104+
*
105+
* Unsigned types are supported too even though they are not a [Number] instance,
106+
* but unless the input solely exists of unsigned numbers, it will never be returned.
107+
* Meaning, given a [Number] in the input, the output will always be a [Number].
108+
*
109+
* @return The nearest common numeric type between the input types.
110+
* If no common type is found, it returns [Number].
111+
* @see UnifyingNumbers
112+
*/
113+
internal fun Iterable<KType>.unifiedNumberType(): KType =
114+
fold(null as KType?, ::getUnifiedNumberType) ?: typeOf<Number>()
115+
116+
/** @include [unifiedNumberType] */
117+
internal fun Iterable<KClass<*>>.unifiedNumberClass(): KClass<*> =
118+
fold(null as KClass<*>?, ::getUnifiedNumberClass) ?: Number::class
119+
120+
/**
121+
* Converts the elements of the given iterable of numbers into a common numeric type based on complexity.
122+
* The common numeric type is determined using the provided [commonNumberType] parameter
123+
* or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified.
124+
*
125+
* @param commonNumberType The desired common numeric type to convert the elements to.
126+
* This is determined by default using the types of the elements in the iterable.
127+
* @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
128+
* @throws IllegalStateException if an element cannot be converted to the common number type.
129+
* @see UnifyingNumbers
130+
*/
131+
@Suppress("UNCHECKED_CAST")
132+
internal fun Iterable<Number>.convertToUnifiedNumberType(
133+
commonNumberType: KType = this.types().unifiedNumberType(),
134+
): Iterable<Number> {
135+
val converter = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?
136+
return map {
137+
converter(it) ?: error("Can not convert $it to $commonNumberType")
138+
}
139+
}

0 commit comments

Comments
 (0)