Skip to content

Unified number types #1070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.jetbrains.kotlinx.dataframe.impl.zero
import org.jetbrains.kotlinx.dataframe.math.sum
import org.jetbrains.kotlinx.dataframe.math.sumOf
import kotlin.reflect.KProperty
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.typeOf

// region DataColumn
Expand All @@ -42,7 +43,11 @@ public inline fun <T, reified R : Number> DataColumn<T>.sumOf(crossinline expres

// region DataRow

// NOTE(review): judging by the hunk header (-42,7 +43,11), this one-line form appears to be the
// pre-change side of the diff and the multi-line `rowSum` that follows is its replacement — confirm.
// Sums every cell of the row that is a Number instance via Aggregators.sum.aggregateMixed and
// falls back to 0 when that call yields null.
public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateMixed(values().filterIsInstance<Number>()) ?: 0
/**
 * Sums all cells of this row that are [Number] instances.
 *
 * Besides the numeric values themselves, the set of column types that are subtypes of `Number?`
 * is handed to the mixed-number aggregator.
 *
 * @return the aggregated sum, or `0` when the aggregator yields `null`.
 */
public fun AnyRow.rowSum(): Number {
    val nullableNumberType = typeOf<Number?>()
    val numbers = values().filterIsInstance<Number>()
    val numberColumnTypes = columnTypes()
        .filter { columnType -> columnType.isSubtypeOf(nullableNumberType) }
        .toSet()

    return Aggregators.sum.aggregateMixed(values = numbers, types = numberColumnTypes) ?: 0
}

/**
 * Sums only those cells of this row whose runtime value is an instance of [T].
 *
 * @param T the number type to pick from the row and to sum as.
 * @return the sum of the matching cells, computed for the type `typeOf<T>()`.
 */
public inline fun <reified T : Number> AnyRow.rowSumOf(): T {
    val matchingValues = values().filterIsInstance<T>()
    return matchingValues.sum(typeOf<T>())
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package org.jetbrains.kotlinx.dataframe.documentation

/**
 * ## Unifying Numbers
 *
 * The concept of unifying numbers is converting them to a common number type without losing information.
 *
 * The following graph shows the hierarchy of number types in Kotlin DataFrame.
 * The order is top-down from the most complex type to the simplest one.
 *
 * {@include [Graph]}
 * For each number type in the graph, it holds that a number of that type can be expressed losslessly by
 * a number of a more complex type (any of its parents).
 * This is either because the more complex type has a larger range or higher precision (in terms of bits).
 */
internal interface UnifyingNumbers {

    // The `..` markers presumably denote edges that continue on the other side of the drawing:
    // the graph built in `unifiedNumberTypeGraph` adds Double -> UInt and Float -> UShort.
    // Backslashes are written doubled here, as in the original comment.

    /**
     * ```
     *           BigDecimal
     *            /      \\
     *     BigInteger     \\
     *        /   \\        \\
     *   ULong   Long    Double
     * ..   |    /   |   /   |  \\..
     *   \\  |   /    |  /    |
     *     UInt     Int    Float
     * ..   |    /   |   /      \\..
     *   \\  |   /    |  /
     *    UShort   Short
     *      |   /    |
     *      |  /     |
     *    UByte     Byte
     * ```
     */
    interface Graph
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
package org.jetbrains.kotlinx.dataframe.impl

import kotlin.experimental.ExperimentalTypeInference

/**
 * Represents a directed acyclic graph (DAG) of generic type [T].
 *
 * This class is immutable and guarantees that the graph does not contain any cycles:
 * [Builder.build] rejects cyclic input and hands the graph its own snapshot of the builder's state,
 * so changes made to a builder after [Builder.build] never affect an already built graph.
 * It provides functionality to find the nearest common ancestor of two vertices
 * in the graph ([findNearestCommonVertex]).
 *
 * Use the [Builder] class or [buildDag] function to create a new instance of this class.
 *
 * @param T The type of items in the graph.
 * @property adjacencyList A map representing directed edges, where the keys are source vertices
 *   and the values are sets of destination vertices.
 * @property vertices A set of all vertices in the graph.
 */
internal class DirectedAcyclicGraph<T> private constructor(
    private val adjacencyList: Map<T, Set<T>>,
    private val vertices: Set<T>,
) {
    /**
     * Reverse index of [adjacencyList]: for every destination vertex, its direct parents.
     * Parents are listed in the iteration order of [adjacencyList], so walking this index visits
     * ancestors in exactly the order a full scan over [adjacencyList] would.
     */
    private val parentsByChild: Map<T, List<T>> =
        mutableMapOf<T, MutableList<T>>().also { parents ->
            for ((parent, children) in adjacencyList) {
                for (child in children) {
                    parents.getOrPut(child) { mutableListOf() }.add(parent)
                }
            }
        }

    /**
     * Mutable collector of edges from which a [DirectedAcyclicGraph] is built.
     * Vertices are registered implicitly by the edges they take part in.
     */
    class Builder<T> {
        private val edges = mutableListOf<Pair<T, T>>()
        private val vertices = mutableSetOf<T>()

        /** Adds a directed edge [from] -> [to] and registers both vertices. Returns this builder. */
        fun addEdge(from: T, to: T): Builder<T> {
            edges.add(from to to)
            vertices.add(from)
            vertices.add(to)
            return this
        }

        /** Adds all [edges] (as `from to to` pairs) in the given order. Returns this builder. */
        fun addEdges(vararg edges: Pair<T, T>): Builder<T> {
            edges.forEach { (from, to) -> addEdge(from, to) }
            return this
        }

        /**
         * Creates the graph from the edges added so far.
         *
         * @throws IllegalStateException if the edges contain a cycle.
         */
        fun build(): DirectedAcyclicGraph<T> {
            val adjacencyList = edges.groupBy({ it.first }, { it.second })
                .mapValues { it.value.toSet() }

            check(!hasCycle(adjacencyList)) { "Graph contains cycle" }

            // Copy the vertex set: handing out the builder's own mutable set would let later
            // addEdge() calls on this builder modify the supposedly immutable graph.
            return DirectedAcyclicGraph(adjacencyList, vertices.toSet())
        }

        /** Depth-first search for a back edge, i.e. an edge to a vertex on the current DFS path. */
        private fun hasCycle(adjacencyList: Map<T, Set<T>>): Boolean {
            val visited = mutableSetOf<T>()
            val recursionStack = mutableSetOf<T>()

            fun dfs(vertex: T): Boolean {
                if (vertex in recursionStack) return true
                if (vertex in visited) return false

                visited.add(vertex)
                recursionStack.add(vertex)

                // `any` stops at the first neighbor that closes a cycle.
                val cycleFound = adjacencyList[vertex].orEmpty().any { neighbor -> dfs(neighbor) }

                recursionStack.remove(vertex)
                return cycleFound
            }

            // Already visited vertices return `false` immediately, so every component is walked once.
            return adjacencyList.keys.any { vertex -> dfs(vertex) }
        }
    }

    /**
     * Finds the vertex nearest to both [vertex1] and [vertex2] from which both can be reached.
     *
     * - If the vertices are equal, that vertex is returned.
     * - If one vertex is an ancestor of the other, the ancestor is returned.
     * - Otherwise, the common ancestor with the smallest summed distance to both vertices is returned;
     *   on a tie the first such ancestor, in the order ancestors of [vertex1] are discovered, wins.
     *
     * @return the nearest common vertex, or `null` if either vertex is not part of this graph
     *   or the vertices have no common ancestor.
     */
    fun findNearestCommonVertex(vertex1: T, vertex2: T): T? {
        if (vertex1 !in vertices || vertex2 !in vertices) return null
        if (vertex1 == vertex2) return vertex1

        // Get all ancestors for both vertices
        val ancestors1 = getAllAncestors(vertex1)
        val ancestors2 = getAllAncestors(vertex2)

        // If one vertex is an ancestor of another, return that vertex
        if (vertex1 in ancestors2) return vertex1
        if (vertex2 in ancestors1) return vertex2

        // Find common ancestors
        val commonAncestors = ancestors1.intersect(ancestors2)
        if (commonAncestors.isEmpty()) return null

        // Find the nearest common ancestor by checking distance from both vertices.
        // Both vertices are reachable from a common ancestor, so neither distance is the
        // Int.MAX_VALUE "unreachable" marker and the sum cannot overflow.
        return commonAncestors.minByOrNull { ancestor ->
            getDistance(ancestor, vertex1) + getDistance(ancestor, vertex2)
        }
    }

    /** Collects every vertex from which [vertex] can be reached, excluding [vertex] itself. */
    private fun getAllAncestors(vertex: T): Set<T> {
        val ancestors = mutableSetOf<T>()
        val visited = mutableSetOf<T>()

        fun dfs(current: T) {
            // `add` returns false when `current` has been expanded before.
            if (!visited.add(current)) return

            for (parent in parentsByChild[current].orEmpty()) {
                ancestors.add(parent)
                dfs(parent)
            }
        }

        dfs(vertex)
        return ancestors
    }

    /**
     * Breadth-first search for the number of edges on the shortest path [from] -> [to].
     *
     * @return the path length, `0` if both vertices are equal,
     *   or [Int.MAX_VALUE] if [to] cannot be reached from [from].
     */
    private fun getDistance(from: T, to: T): Int {
        if (from == to) return 0

        val distances = mutableMapOf<T, Int>()
        val queue = ArrayDeque<T>()

        queue.add(from)
        distances[from] = 0

        while (queue.isNotEmpty()) {
            val current = queue.removeFirst()
            val currentDistance = distances[current] ?: continue

            adjacencyList[current]?.forEach { neighbor ->
                if (neighbor !in distances) {
                    distances[neighbor] = currentDistance + 1
                    queue.add(neighbor)
                    if (neighbor == to) return currentDistance + 1
                }
            }
        }

        return Int.MAX_VALUE
    }

    /**
     * Creates a new graph with the same edges, where every vertex is replaced by
     * the result of [conversion]. Conversion results are cached per source vertex.
     *
     * @throws IllegalStateException if the converted edges contain a cycle,
     *   for instance when two connected vertices are converted to the same value.
     */
    fun <R> map(conversion: (T) -> R): DirectedAcyclicGraph<R> {
        val cache = mutableMapOf<T, R>()
        val cachedConversion: (T) -> R = { cache.getOrPut(it) { conversion(it) } }

        return Builder<R>().apply {
            for ((from, children) in adjacencyList) {
                for (to in children) {
                    addEdge(from = cachedConversion(from), to = cachedConversion(to))
                }
            }
        }.build()
    }

    companion object {
        /** Creates an empty [Builder]. */
        fun <T> builder(): Builder<T> = Builder()
    }
}

/**
 * Creates a [DirectedAcyclicGraph] by running [builder] on a fresh [DirectedAcyclicGraph.Builder]
 * and building the result.
 *
 * @param builder configuration block that adds the edges of the graph.
 * @throws IllegalStateException if the added edges contain a cycle.
 * @see DirectedAcyclicGraph
 */
@OptIn(ExperimentalTypeInference::class)
internal fun <T> buildDag(
    @BuilderInference builder: DirectedAcyclicGraph.Builder<T>.() -> Unit,
): DirectedAcyclicGraph<T> {
    val dagBuilder = DirectedAcyclicGraph.builder<T>()
    dagBuilder.builder()
    return dagBuilder.build()
}

/**
 * Creates a [DirectedAcyclicGraph] from the given [edges], added in the given order.
 *
 * @param edges the directed edges of the graph as `from to to` pairs.
 * @throws IllegalStateException if [edges] contain a cycle.
 * @see DirectedAcyclicGraph
 */
internal fun <T> dagOf(vararg edges: Pair<T, T>): DirectedAcyclicGraph<T> =
    DirectedAcyclicGraph.builder<T>()
        .addEdges(*edges)
        .build()
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package org.jetbrains.kotlinx.dataframe.impl

import org.jetbrains.kotlinx.dataframe.documentation.UnifyingNumbers
import org.jetbrains.kotlinx.dataframe.impl.api.createConverter
import java.math.BigDecimal
import java.math.BigInteger
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability
import kotlin.reflect.typeOf

/**
 * Number type graph, structured in terms of number complexity.
 * A number can always be expressed losslessly by a number of a more complex type (any of its parents).
 *
 * {@include [UnifyingNumbers.Graph]}
 *
 * For any two number types, the nearest common ancestor in this graph can be found
 * by calling [DirectedAcyclicGraph.findNearestCommonVertex].
 * @see getUnifiedNumberClass
 * @see unifiedNumberClass
 * @see UnifyingNumbers
 */
internal val unifiedNumberTypeGraph: DirectedAcyclicGraph<KType> by lazy {
    val bigDecimal = typeOf<BigDecimal>()
    val bigInteger = typeOf<BigInteger>()
    val double = typeOf<Double>()
    val float = typeOf<Float>()
    val long = typeOf<Long>()
    val int = typeOf<Int>()
    val short = typeOf<Short>()
    val byte = typeOf<Byte>()
    val uLong = typeOf<ULong>()
    val uInt = typeOf<UInt>()
    val uShort = typeOf<UShort>()
    val uByte = typeOf<UByte>()

    // Edges point from the more complex (parent) type to the simpler (child) type.
    // The order of the calls is significant: it determines the iteration order of the graph.
    buildDag {
        addEdge(bigDecimal, bigInteger)
        addEdge(bigDecimal, double)

        addEdge(bigInteger, uLong)
        addEdge(bigInteger, long)

        addEdge(uLong, uInt)

        addEdge(long, uInt)
        addEdge(long, int)

        addEdge(double, int)
        addEdge(double, float)
        addEdge(double, uInt)

        addEdge(uInt, uShort)

        addEdge(int, uShort)
        addEdge(int, short)

        addEdge(float, short)
        addEdge(float, uShort)

        addEdge(uShort, uByte)

        addEdge(short, uByte)
        addEdge(short, byte)
    }
}

/** @include [unifiedNumberTypeGraph] */
internal val unifiedNumberClassGraph: DirectedAcyclicGraph<KClass<*>> by lazy {
    // Same graph as [unifiedNumberTypeGraph], with every type replaced by its classifier.
    unifiedNumberTypeGraph.map { numberType ->
        numberType.classifier as KClass<*>
    }
}

/**
 * Determines the nearest common numeric type, in terms of complexity, between two given classes/types.
 *
 * Unsigned types are supported too, even though they are not [Number] instances,
 * but an unsigned type is only returned when both inputs are unsigned types.
 * Meaning, if at least one input is a [Number], the output will always be a [Number].
 *
 * Nullability is ignored while searching; the result is marked nullable
 * if at least one of the inputs is marked nullable.
 *
 * @param first The first numeric type to compare. Can be null, in which case [second] is returned.
 * @param second The second numeric type to compare. Cannot be null.
 * @return The nearest common numeric type between the two input classes.
 *   If no common class is found, [IllegalStateException] is thrown.
 * @see UnifyingNumbers
 */
internal fun getUnifiedNumberType(first: KType?, second: KType): KType {
    if (first == null) return second

    val nonNullFirst = first.withNullability(false)
    val nonNullSecond = second.withNullability(false)

    // Equal types unify to themselves, whether or not they are part of the graph.
    val unified = when (nonNullFirst) {
        nonNullSecond -> nonNullFirst

        else ->
            unifiedNumberTypeGraph.findNearestCommonVertex(nonNullFirst, nonNullSecond)
                ?: error("Can not find common number type for $first and $second")
    }

    val anyInputNullable = first.isMarkedNullable || second.isMarkedNullable
    if (!anyInputNullable) return unified
    return unified.withNullability(true)
}

/** @include [getUnifiedNumberType] */
internal fun getUnifiedNumberClass(first: KClass<*>?, second: KClass<*>): KClass<*> {
    // Nothing to unify with yet: the second class is the result.
    if (first == null) return second

    // Equal classes unify to themselves, whether or not they are part of the graph.
    if (first == second) return first

    return unifiedNumberClassGraph.findNearestCommonVertex(first, second)
        ?: error("Can not find common number type for $first and $second")
}

/**
 * Determines the nearest common numeric type, in terms of complexity, of all types in [this].
 *
 * Unsigned types are supported too, even though they are not [Number] instances,
 * but an unsigned type is only returned when the input consists solely of unsigned types.
 * Meaning, given a [Number] in the input, the output will always be a [Number].
 *
 * @return The nearest common numeric type between the input types.
 *   If the input is empty, it returns [Number].
 *   If no common type is found, [IllegalStateException] is thrown.
 * @see UnifyingNumbers
 */
internal fun Iterable<KType>.unifiedNumberType(): KType {
    var unified: KType? = null
    for (type in this) {
        unified = getUnifiedNumberType(unified, type)
    }
    return unified ?: typeOf<Number>()
}

/** @include [unifiedNumberType] */
internal fun Iterable<KClass<*>>.unifiedNumberClass(): KClass<*> {
    var unified: KClass<*>? = null
    for (numberClass in this) {
        unified = getUnifiedNumberClass(unified, numberClass)
    }
    return unified ?: Number::class
}

/**
 * Converts all numbers in this iterable to one common numeric type.
 * That type is [commonNumberType], which by default is calculated with
 * [Iterable.unifiedNumberType] from the types of the elements of this iterable.
 *
 * @param commonNumberType The common numeric type to convert the elements to.
 *   Determined from the types of the elements in the iterable when not specified.
 * @return A new iterable of numbers where each element is converted to the specified
 *   or inferred common number type.
 * @throws IllegalStateException if an element cannot be converted to the common number type.
 * @see UnifyingNumbers
 */
@Suppress("UNCHECKED_CAST")
internal fun Iterable<Number>.convertToUnifiedNumberType(
    commonNumberType: KType = this.types().unifiedNumberType(),
): Iterable<Number> {
    // `!!` fails with a NullPointerException when no converter is available for the target type.
    val toCommonType = createConverter(typeOf<Number>(), commonNumberType)!! as (Number) -> Number?

    val converted = ArrayList<Number>()
    for (number in this) {
        converted += toCommonType(number) ?: error("Can not convert $number to $commonNumberType")
    }
    return converted
}
Loading