Skip to content

Commit 015eb99

Browse files
committed
Fixed the getCommonNumberType and commonNumberClass functions, which are used solely by sum at the moment. I introduced a future-proof rewrite of these functions and added support for unsigned and big numbers. Moved them to a separate file and added tests. We will reuse this logic in more places later. NumbersAggregator now converts the numbers in its input to a common number type before aggregating, so it no longer relies on smart-casts. To avoid heavy reflection calls, the types can be supplied to aggregateMixed() if you already know them.
1 parent 6bf0a3f commit 015eb99

File tree

9 files changed

+524
-41
lines changed

9 files changed

+524
-41
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.jetbrains.kotlinx.dataframe.impl.zero
2525
import org.jetbrains.kotlinx.dataframe.math.sum
2626
import org.jetbrains.kotlinx.dataframe.math.sumOf
2727
import kotlin.reflect.KProperty
28+
import kotlin.reflect.full.isSubtypeOf
2829
import kotlin.reflect.typeOf
2930

3031
// region DataColumn
@@ -42,7 +43,11 @@ public inline fun <T, reified R : Number> DataColumn<T>.sumOf(crossinline expres
4243

4344
// region DataRow
4445

45-
public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateMixed(values().filterIsInstance<Number>()) ?: 0
46+
/**
 * Sums all [Number] values of this row.
 *
 * The values are every cell of the row that is a [Number] instance; the types handed to the
 * aggregator are the declared column types that are subtypes of `Number?`.
 *
 * NOTE(review): a numeric value stored in a column whose declared type is broader than `Number?`
 * (e.g. `Any?`) is included in the values but its type is not included in the types —
 * confirm that `aggregateMixed` handles this as intended.
 *
 * @return the aggregated sum, or `0` when the aggregator yields `null`.
 */
public fun AnyRow.rowSum(): Number {
    val numbers = values().filterIsInstance<Number>()
    val nullableNumberType = typeOf<Number?>()
    val numericColumnTypes = columnTypes()
        .filter { columnType -> columnType.isSubtypeOf(nullableNumberType) }
        .toSet()

    return Aggregators.sum.aggregateMixed(values = numbers, types = numericColumnTypes) ?: 0
}
4651

4752
/**
 * Sums the values of this row that are instances of [T].
 *
 * Values of any other type are ignored; the summation itself is delegated to `sum` with `typeOf<T>()`.
 */
public inline fun <reified T : Number> AnyRow.rowSumOf(): T {
    val valuesOfType = values().filterIsInstance<T>()
    return valuesOfType.sum(typeOf<T>())
}
4853

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
package org.jetbrains.kotlinx.dataframe.impl
2+
3+
import kotlin.experimental.ExperimentalTypeInference
4+
5+
/**
 * Represents a directed acyclic graph (DAG) of generic type [T].
 *
 * Edges point from a source vertex to a destination vertex. An *ancestor* of a vertex is any
 * vertex from which that vertex can be reached by following edges.
 *
 * This class is immutable and guarantees that the graph does not contain any cycles:
 * [Builder.build] snapshots the builder's edges and vertices and throws if they form a cycle.
 * It provides functionality to find the nearest common ancestor of two vertices
 * in the graph ([findNearestCommonVertex]).
 *
 * Use the [Builder] class or [buildDag] function to create a new instance of this class.
 *
 * @param T The type of items in the graph.
 * @property adjacencyList A map representing directed edges, where the keys are source vertices
 *   and the values are sets of destination vertices.
 * @property vertices A set of all vertices in the graph.
 */
internal class DirectedAcyclicGraph<T> private constructor(
    private val adjacencyList: Map<T, Set<T>>,
    private val vertices: Set<T>,
) {
    /**
     * Reverse index of [adjacencyList]: for every destination vertex, its direct source vertices.
     *
     * Sources are listed in [adjacencyList] iteration order, so ancestor traversal visits vertices
     * in the same order as a scan over [adjacencyList] would.
     */
    private val parentsByChild: Map<T, List<T>> =
        mutableMapOf<T, MutableList<T>>().also { index ->
            for ((parent, children) in adjacencyList) {
                for (child in children) {
                    index.getOrPut(child) { mutableListOf() }.add(parent)
                }
            }
        }

    /** Mutable builder that collects edges and produces an immutable [DirectedAcyclicGraph]. */
    class Builder<T> {
        private val edges = mutableListOf<Pair<T, T>>()
        private val vertices = mutableSetOf<T>()

        /** Adds a directed edge [from] -> [to], registering both endpoints as vertices. */
        fun addEdge(from: T, to: T): Builder<T> {
            edges.add(from to to)
            vertices.add(from)
            vertices.add(to)
            return this
        }

        /** Adds every `(from, to)` pair in [edges] as a directed edge. */
        fun addEdges(vararg edges: Pair<T, T>): Builder<T> {
            edges.forEach { (from, to) -> addEdge(from, to) }
            return this
        }

        /**
         * Builds the graph from the edges added so far.
         *
         * @throws IllegalStateException if the edges contain a cycle.
         */
        fun build(): DirectedAcyclicGraph<T> {
            val adjacencyList = edges
                .groupBy({ it.first }, { it.second })
                .mapValues { it.value.toSet() }

            check(!hasCycle(adjacencyList)) { "Graph contains cycle" }

            // Copy the vertex set: the builder stays usable after build(), and later addEdge()
            // calls must not leak into an already built (immutable) graph.
            return DirectedAcyclicGraph(adjacencyList, vertices.toSet())
        }

        /** Depth-first search for a back edge, i.e. an edge to a vertex on the current DFS path. */
        private fun hasCycle(adjacencyList: Map<T, Set<T>>): Boolean {
            val visited = mutableSetOf<T>()
            val recursionStack = mutableSetOf<T>()

            fun dfs(vertex: T): Boolean {
                if (vertex in recursionStack) return true
                if (!visited.add(vertex)) return false

                recursionStack.add(vertex)
                val cycleFound = adjacencyList[vertex].orEmpty().any { neighbor -> dfs(neighbor) }
                if (!cycleFound) recursionStack.remove(vertex)
                return cycleFound
            }

            return adjacencyList.keys.any { vertex -> vertex !in visited && dfs(vertex) }
        }
    }

    /**
     * Finds the nearest vertex from which both [vertex1] and [vertex2] can be reached.
     *
     * If the two vertices are equal, or one is an ancestor of the other, that vertex itself is returned.
     * Otherwise, the common ancestor with the smallest summed edge distance to both vertices is returned.
     *
     * @return the nearest common vertex, or `null` if either vertex is not part of the graph
     *   or the vertices have no common ancestor.
     */
    fun findNearestCommonVertex(vertex1: T, vertex2: T): T? {
        if (vertex1 !in vertices || vertex2 !in vertices) return null
        if (vertex1 == vertex2) return vertex1

        val ancestors1 = getAllAncestors(vertex1)
        val ancestors2 = getAllAncestors(vertex2)

        // If one vertex is an ancestor of the other, it is the answer.
        if (vertex1 in ancestors2) return vertex1
        if (vertex2 in ancestors1) return vertex2

        // Every common ancestor reaches both vertices, so neither distance is Int.MAX_VALUE here.
        return ancestors1.intersect(ancestors2).minByOrNull { ancestor ->
            getDistance(ancestor, vertex1) + getDistance(ancestor, vertex2)
        }
    }

    /** Collects every vertex from which [vertex] is reachable (excluding [vertex] itself). */
    private fun getAllAncestors(vertex: T): Set<T> {
        val ancestors = mutableSetOf<T>()
        val visited = mutableSetOf<T>()

        fun dfs(current: T) {
            if (!visited.add(current)) return

            for (parent in parentsByChild[current].orEmpty()) {
                ancestors.add(parent)
                dfs(parent)
            }
        }

        dfs(vertex)
        return ancestors
    }

    /**
     * Breadth-first search for the smallest number of edges leading from [from] to [to].
     *
     * @return the distance, `0` if both are equal, or [Int.MAX_VALUE] if [to] is unreachable.
     */
    private fun getDistance(from: T, to: T): Int {
        if (from == to) return 0

        val distances = mutableMapOf<T, Int>()
        val queue = ArrayDeque<T>()

        queue.add(from)
        distances[from] = 0

        while (queue.isNotEmpty()) {
            val current = queue.removeFirst()
            val nextDistance = (distances[current] ?: continue) + 1

            for (neighbor in adjacencyList[current].orEmpty()) {
                if (neighbor in distances) continue
                if (neighbor == to) return nextDistance
                distances[neighbor] = nextDistance
                queue.add(neighbor)
            }
        }

        return Int.MAX_VALUE
    }

    /**
     * Creates a new graph with the same edges, where each vertex is replaced by [conversion] of it.
     *
     * Non-null conversion results are cached per source vertex, so [conversion] is not invoked again for them.
     *
     * @throws IllegalStateException if the converted edges contain a cycle
     *   (possible when [conversion] maps different vertices to the same value).
     */
    fun <R> map(conversion: (T) -> R): DirectedAcyclicGraph<R> {
        val cache = mutableMapOf<T, R>()

        fun convert(vertex: T): R = cache.getOrPut(vertex) { conversion(vertex) }

        val builder = Builder<R>()
        for ((source, destinations) in adjacencyList) {
            for (destination in destinations) {
                builder.addEdge(from = convert(source), to = convert(destination))
            }
        }
        return builder.build()
    }

    companion object {
        /** Creates an empty [Builder]. */
        fun <T> builder(): Builder<T> = Builder()
    }
}
160+
161+
/**
 * Creates a [DirectedAcyclicGraph] by running [builder] on a fresh [DirectedAcyclicGraph.Builder]
 * and building the result.
 *
 * @see DirectedAcyclicGraph
 */
@OptIn(ExperimentalTypeInference::class)
internal fun <T> buildDag(
    @BuilderInference builder: DirectedAcyclicGraph.Builder<T>.() -> Unit,
): DirectedAcyclicGraph<T> {
    val dagBuilder = DirectedAcyclicGraph.builder<T>()
    dagBuilder.builder()
    return dagBuilder.build()
}
170+
171+
/**
 * Creates a [DirectedAcyclicGraph] from the given [edges], each pair being `from to to`.
 *
 * @see DirectedAcyclicGraph
 */
internal fun <T> dagOf(vararg edges: Pair<T, T>): DirectedAcyclicGraph<T> =
    DirectedAcyclicGraph.builder<T>()
        .addEdges(*edges)
        .build()
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
package org.jetbrains.kotlinx.dataframe.impl
2+
3+
import org.jetbrains.kotlinx.dataframe.impl.api.createConverter
4+
import org.jetbrains.kotlinx.dataframe.impl.commonNumberType
5+
import java.math.BigDecimal
6+
import java.math.BigInteger
7+
import kotlin.reflect.KClass
8+
import kotlin.reflect.KType
9+
import kotlin.reflect.full.withNullability
10+
import kotlin.reflect.typeOf
11+
12+
/**
 * Number type graph, structured in terms of number complexity.
 * A number can always be expressed losslessly by a number of a more complex type (any of its parents).
 *
 * ```
 *            BigDecimal
 *           /          \\
 *     BigInteger        |
 *         |             |
 *       ULong           |
 *         |             |
 *        Long         Double
 *           \\       /   |
 *             UInt       |
 *              |         |
 *             Int      Float
 *              |      /
 *             UShort
 *               |
 *             Short
 *               |
 *             UByte
 *               |
 *              Byte
 * ```
 *
 * For any two numbers, we can find the nearest common ancestor in this graph
 * by calling [DirectedAcyclicGraph.findNearestCommonVertex].
 * @see getCommonNumberClass
 * @see commonNumberClass
 */
internal val numberTypeGraph: DirectedAcyclicGraph<KType> by lazy {
    val bigDecimal = typeOf<BigDecimal>()
    val bigInteger = typeOf<BigInteger>()
    val double = typeOf<Double>()
    val float = typeOf<Float>()
    val uLong = typeOf<ULong>()
    val long = typeOf<Long>()
    val uInt = typeOf<UInt>()
    val int = typeOf<Int>()
    val uShort = typeOf<UShort>()
    val short = typeOf<Short>()
    val uByte = typeOf<UByte>()
    val byte = typeOf<Byte>()

    // Each edge points from the more complex type to the less complex one.
    dagOf(
        bigDecimal to bigInteger,
        bigDecimal to double,
        bigInteger to uLong,
        uLong to long,
        long to uInt,
        double to uInt,
        double to float,
        uInt to int,
        int to uShort,
        float to uShort,
        uShort to short,
        short to uByte,
        uByte to byte,
    )
}
60+
61+
/** @include [numberTypeGraph] */
internal val numberClassGraph: DirectedAcyclicGraph<KClass<*>> by lazy {
    // Same edges as numberTypeGraph, with every KType vertex replaced by its classifier class.
    numberTypeGraph.map { type -> type.classifier as KClass<*> }
}
65+
66+
/**
 * Determines the nearest common numeric type, in terms of complexity, between two given types.
 *
 * Unsigned types are supported too even though they are not a [Number] instance,
 * but unless an unsigned type is provided in the input, it will never be returned.
 * Meaning, given two [Number] inputs, the output will always be a [Number].
 *
 * Nullability is ignored while searching; the result is marked nullable
 * if either [first] or [second] is marked nullable.
 *
 * @param first The first numeric type to compare. Can be null, in which case [second] is returned as is.
 * @param second The second numeric type to compare. Cannot be null.
 * @return The nearest common numeric type between the two input types.
 * @throws IllegalStateException if no common type is found.
 * @see numberTypeGraph
 */
internal fun getCommonNumberType(first: KType?, second: KType): KType {
    if (first == null) return second

    val nonNullFirst = first.withNullability(false)
    val nonNullSecond = second.withNullability(false)

    val common = when (nonNullFirst) {
        nonNullSecond -> nonNullFirst

        else ->
            numberTypeGraph.findNearestCommonVertex(nonNullFirst, nonNullSecond)
                ?: error("Can not find common number type for $first and $second")
    }

    val anyNullable = first.isMarkedNullable || second.isMarkedNullable
    return if (anyNullable) common.withNullability(true) else common
}
94+
95+
/** @include [getCommonNumberType] */
internal fun getCommonNumberClass(first: KClass<*>?, second: KClass<*>): KClass<*> {
    // Nothing to compare against yet: the second class is the answer.
    if (first == null) return second

    // Identical classes need no graph lookup.
    if (first == second) return first

    return numberClassGraph.findNearestCommonVertex(first, second)
        ?: error("Can not find common number type for $first and $second")
}
106+
107+
/**
 * Determines the nearest common numeric type, in terms of complexity, of all types in [this].
 *
 * Unsigned types are supported too even though they are not a [Number] instance,
 * but unless an unsigned type is provided in the input, it will never be returned.
 * Meaning, given just [Number] inputs, the output will always be a [Number].
 *
 * @return The nearest common numeric type between the input types.
 *   If [this] is empty, it returns [Number].
 * @throws IllegalStateException if the types have no common number type (see [getCommonNumberType]).
 * @see numberTypeGraph
 */
internal fun Iterable<KType>.commonNumberType(): KType {
    var common: KType? = null
    for (type in this) {
        common = getCommonNumberType(common, type)
    }
    return common ?: typeOf<Number>()
}
119+
120+
/** @include [commonNumberType] */
internal fun Iterable<KClass<*>>.commonNumberClass(): KClass<*> {
    var common: KClass<*>? = null
    for (numberClass in this) {
        common = getCommonNumberClass(common, numberClass)
    }
    // Only an empty receiver leaves `common` unset.
    return common ?: Number::class
}
123+
124+
/**
 * Converts the elements of the given iterable of numbers into a common numeric type based on complexity.
 * The common numeric type is determined using the provided [commonNumberType] parameter
 * or calculated with [Iterable.commonNumberType] from the iterable's elements if not explicitly specified.
 *
 * @param commonNumberType The desired common numeric type to convert the elements to.
 *   This is determined by default using the types of the elements in the iterable.
 * @return A new iterable of numbers where each element is converted to the specified or inferred common number type.
 * @throws IllegalStateException if no converter to [commonNumberType] exists,
 *   or if an element cannot be converted to the common number type.
 * @see Iterable.commonNumberType
 */
@Suppress("UNCHECKED_CAST")
internal fun Iterable<Number>.convertToCommonNumberType(
    commonNumberType: KType = this.types().commonNumberType(),
): Iterable<Number> {
    // Fail with a descriptive IllegalStateException instead of a bare NullPointerException
    // when no converter is available for the requested type.
    val rawConverter = createConverter(typeOf<Number>(), commonNumberType)
        ?: error("Can not create a converter from ${typeOf<Number>()} to $commonNumberType")
    val converter = rawConverter as (Number) -> Number?

    return map { number ->
        converter(number) ?: error("Can not convert $number to $commonNumberType")
    }
}

0 commit comments

Comments
 (0)