Skip to content

Commit b4c4e36

Browse files
committed
added tests and fixed NaN behavior for min/max functions
1 parent 87ab8ed commit b4c4e36

File tree

11 files changed

+675
-56
lines changed

11 files changed

+675
-56
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
77
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
88
import org.jetbrains.kotlinx.dataframe.columns.FrameColumn
99
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
10+
import org.jetbrains.kotlinx.dataframe.impl.isIntraComparable
1011
import org.jetbrains.kotlinx.dataframe.impl.isMixedNumber
1112
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber
1213
import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber
@@ -18,11 +19,7 @@ import org.jetbrains.kotlinx.dataframe.util.IS_INTER_COMPARABLE_IMPORT
1819
import kotlin.contracts.ExperimentalContracts
1920
import kotlin.contracts.contract
2021
import kotlin.reflect.KType
21-
import kotlin.reflect.KTypeProjection
22-
import kotlin.reflect.KVariance
23-
import kotlin.reflect.full.createType
2422
import kotlin.reflect.full.isSubtypeOf
25-
import kotlin.reflect.full.withNullability
2623
import kotlin.reflect.typeOf
2724

2825
public fun AnyCol.isColumnGroup(): Boolean {
@@ -87,13 +84,4 @@ public fun AnyCol.isComparable(): Boolean = valuesAreComparable()
8784
*
8885
* Technically, this means the values' common type `T(?)` is a subtype of [Comparable]`<in T>(?)`
8986
*/
90-
public fun AnyCol.valuesAreComparable(): Boolean =
91-
isValueColumn() &&
92-
isSubtypeOf(
93-
Comparable::class.createType(
94-
arguments = listOf(
95-
KTypeProjection(KVariance.IN, type().withNullability(false)),
96-
),
97-
nullable = hasNulls(),
98-
),
99-
)
87+
public fun AnyCol.valuesAreComparable(): Boolean = isValueColumn() && type().isIntraComparable()

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/max.kt

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
2121
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
2222
import org.jetbrains.kotlinx.dataframe.impl.columns.toComparableColumns
2323
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
24+
import org.jetbrains.kotlinx.dataframe.util.ROW_MAX
25+
import org.jetbrains.kotlinx.dataframe.util.ROW_MAX_OR_NULL
2426
import kotlin.reflect.KProperty
2527

2628
// region DataColumn
@@ -55,14 +57,11 @@ public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.maxOfOrNul
5557

5658
// region DataRow
5759

58-
@Deprecated("", level = DeprecationLevel.ERROR)
59-
public fun AnyRow.rowMaxOrNull(): Any? =
60-
error("") // values().filterIsInstance<Comparable<*>>().maxWithOrNull(compareBy { it })
60+
@Deprecated(ROW_MAX_OR_NULL, level = DeprecationLevel.ERROR)
61+
public fun AnyRow.rowMaxOrNull(): Any? = error(ROW_MAX_OR_NULL)
6162

62-
@Deprecated("", level = DeprecationLevel.ERROR)
63-
public fun AnyRow.rowMax(): Any = error("") // rowMaxOrNull().suggestIfNull("rowMax")
64-
65-
// todo add rowMaxBy?
63+
@Deprecated(ROW_MAX, level = DeprecationLevel.ERROR)
64+
public fun AnyRow.rowMax(): Any = error(ROW_MAX)
6665

6766
public inline fun <reified T : Comparable<T & Any>?> AnyRow.rowMaxOfOrNull(skipNaN: Boolean = skipNaN_default): T? =
6867
Aggregators.max<T>(skipNaN).aggregateOfRow(this) { colsOf<T>() }
@@ -74,7 +73,6 @@ public inline fun <reified T : Comparable<T & Any>?> AnyRow.rowMaxOf(skipNaN: Bo
7473

7574
// region DataFrame
7675

77-
// TODO intraComparableOrNumber
7876
public fun <T> DataFrame<T>.max(skipNaN: Boolean = skipNaN_default): DataRow<T> =
7977
maxFor(skipNaN, intraComparableColumns())
8078

@@ -192,7 +190,6 @@ public inline fun <T, reified C : Comparable<C & Any>?> DataFrame<T>.maxByOrNull
192190

193191
// region GroupBy
194192

195-
// TODO intraComparableOrNumber
196193
@Refine
197194
@Interpretable("GroupByMax1")
198195
public fun <T> Grouped<T>.max(skipNaN: Boolean = skipNaN_default): DataFrame<T> =
@@ -361,7 +358,6 @@ public inline fun <T, reified C : Comparable<C & Any>?> Pivot<T>.maxBy(
361358

362359
// region PivotGroupBy
363360

364-
// TODO intraComparableOrNumber
365361
public fun <T> PivotGroupBy<T>.max(separate: Boolean = false, skipNaN: Boolean = skipNaN_default): DataFrame<T> =
366362
maxFor(separate, skipNaN, intraComparableColumns())
367363

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf
2121
import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow
2222
import org.jetbrains.kotlinx.dataframe.impl.columns.toComparableColumns
2323
import org.jetbrains.kotlinx.dataframe.impl.suggestIfNull
24+
import org.jetbrains.kotlinx.dataframe.util.ROW_MIN
25+
import org.jetbrains.kotlinx.dataframe.util.ROW_MIN_OR_NULL
2426
import kotlin.reflect.KProperty
2527

2628
// region DataColumn
@@ -55,14 +57,11 @@ public inline fun <T, reified R : Comparable<R & Any>?> DataColumn<T>.minOfOrNul
5557

5658
// region DataRow
5759

58-
@Deprecated("", level = DeprecationLevel.ERROR)
59-
public fun AnyRow.rowMinOrNull(): Any? =
60-
error("") // values().filterIsInstance<Comparable<*>>().minWithOrNull(compareBy { it })
60+
@Deprecated(ROW_MIN_OR_NULL, level = DeprecationLevel.ERROR)
61+
public fun AnyRow.rowMinOrNull(): Any? = error(ROW_MIN_OR_NULL)
6162

62-
@Deprecated("", level = DeprecationLevel.ERROR)
63-
public fun AnyRow.rowMin(): Any = error("") // rowMinOrNull().suggestIfNull("rowMin")
64-
65-
// todo add rowMinBy?
63+
@Deprecated(ROW_MIN, level = DeprecationLevel.ERROR)
64+
public fun AnyRow.rowMin(): Any = error(ROW_MIN)
6665

6766
public inline fun <reified T : Comparable<T & Any>?> AnyRow.rowMinOfOrNull(skipNaN: Boolean = skipNaN_default): T? =
6867
Aggregators.min<T>(skipNaN).aggregateOfRow(this) { colsOf<T>() }
@@ -74,7 +73,6 @@ public inline fun <reified T : Comparable<T & Any>?> AnyRow.rowMinOf(skipNaN: Bo
7473

7574
// region DataFrame
7675

77-
// TODO intraComparableOrNumber
7876
public fun <T> DataFrame<T>.min(skipNaN: Boolean = skipNaN_default): DataRow<T> =
7977
minFor(skipNaN, intraComparableColumns())
8078

@@ -192,7 +190,6 @@ public inline fun <T, reified C : Comparable<C & Any>?> DataFrame<T>.minByOrNull
192190

193191
// region GroupBy
194192

195-
// TODO intraComparableOrNumber
196193
@Refine
197194
@Interpretable("GroupByMin1")
198195
public fun <T> Grouped<T>.min(skipNaN: Boolean = skipNaN_default): DataFrame<T> =
@@ -361,7 +358,6 @@ public inline fun <T, reified C : Comparable<C & Any>?> Pivot<T>.minBy(
361358

362359
// region PivotGroupBy
363360

364-
// TODO intraComparableOrNumber
365361
public fun <T> PivotGroupBy<T>.min(separate: Boolean = false, skipNaN: Boolean = skipNaN_default): DataFrame<T> =
366362
minFor(separate, skipNaN, intraComparableColumns())
367363

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import org.jetbrains.kotlinx.dataframe.DataColumn
99
import org.jetbrains.kotlinx.dataframe.DataFrame
1010
import org.jetbrains.kotlinx.dataframe.DataRow
1111
import org.jetbrains.kotlinx.dataframe.api.Infer
12+
import org.jetbrains.kotlinx.dataframe.api.isSubtypeOf
1213
import org.jetbrains.kotlinx.dataframe.impl.columns.createColumnGuessingType
1314
import org.jetbrains.kotlinx.dataframe.util.GUESS_VALUE_TYPE
1415
import java.math.BigDecimal
@@ -693,3 +694,16 @@ internal fun Iterable<Any?>.types(): Set<KType> =
693694
mapTo(mutableSetOf()) {
694695
if (it == null) nullableNothingType else it::class.createStarProjectedType(false)
695696
}
697+
698+
/**
699+
* Checks whether this KType adheres to `T : Comparable<T & Any>?`, i.e., whether it is comparable with itself.
700+
*/
701+
internal fun KType.isIntraComparable(): Boolean =
702+
this.isSubtypeOf(
703+
Comparable::class.createType(
704+
arguments = listOf(
705+
KTypeProjection(IN, this.withNullability(false)),
706+
),
707+
nullable = this.isMarkedNullable,
708+
),
709+
)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCal
99
/**
1010
* Generic function to apply an [Aggregator] ([this]) to aggregate values of a row.
1111
*
12-
* [Aggregator.aggregateCalculatingType] is used to deal with mixed types.
12+
* [Aggregator.aggregateCalculatingValueType] is used to deal with mixed types.
1313
*
1414
* @param row a row to aggregate
1515
* @param columns selector of which columns inside the [row] to aggregate
@@ -21,7 +21,7 @@ internal fun <T, V : Any?, R : Any?> Aggregator<V & Any, R>.aggregateOfRow(
2121
): R {
2222
val filteredColumns = row.df().getColumns(columns).asSequence()
2323
return aggregateCalculatingValueType(
24-
values = filteredColumns.mapNotNull { row[it] },
24+
values = filteredColumns.map { row[it] },
2525
valueTypes = filteredColumns.map { it.type() }.toSet(),
2626
)
2727
}

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/max.kt

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,49 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.preserveRetu
55
import org.jetbrains.kotlinx.dataframe.impl.canBeNaN
66
import org.jetbrains.kotlinx.dataframe.impl.indexOfBestNotNaBy
77
import org.jetbrains.kotlinx.dataframe.impl.indexOfBestNotNullBy
8+
import org.jetbrains.kotlinx.dataframe.impl.isIntraComparable
89
import org.jetbrains.kotlinx.dataframe.impl.renderType
910
import kotlin.reflect.KType
11+
import kotlin.reflect.typeOf
1012

1113
@Suppress("UNCHECKED_CAST")
1214
@PublishedApi
1315
internal fun <T : Comparable<T>> Sequence<T>.maxOrNull(type: KType, skipNaN: Boolean): T? {
1416
if (type.isMarkedNullable) {
1517
error("Encountered nullable type ${renderType(type)} in max function. This should not occur.")
1618
}
17-
return if (skipNaN && type.canBeNaN) {
18-
filterNot { it.isNaN }
19-
} else {
20-
this
21-
}.maxOrNull()
19+
if (!type.isIntraComparable()) {
20+
error(
21+
"Encountered non-comparable type ${
22+
renderType(type)
23+
} in max function. Only self-comparable types are supported.",
24+
)
25+
}
26+
27+
return when {
28+
// filter out NaNs if requested
29+
skipNaN && type.canBeNaN -> this.filterNot { it.isNaN }.maxOrNull()
30+
31+
// make sure that NaN is returned if it's in the sequence
32+
type == typeOf<Float>() -> (this as Sequence<Float>).maxOrNull() as T?
33+
34+
// make sure that NaN is returned if it's in the sequence
35+
type == typeOf<Double>() -> (this as Sequence<Double>).maxOrNull() as T?
36+
37+
else -> this.maxOrNull()
38+
}
2239
}
2340

41+
@Suppress("UNCHECKED_CAST")
2442
internal fun <C : Comparable<C>> Sequence<C?>.indexOfMax(type: KType, skipNaN: Boolean): Int =
25-
if (skipNaN && type.canBeNaN) {
26-
indexOfBestNotNaBy { this > it }
27-
} else {
28-
indexOfBestNotNullBy { this > it }
43+
when {
44+
// filter out NaNs if requested
45+
skipNaN && type.canBeNaN -> indexOfBestNotNaBy { this > it }
46+
47+
// make sure the index of the first NaN is returned if it's in the sequence
48+
type.canBeNaN -> indexOfBestNotNullBy { this.isNaN || (!it.isNaN && this > it) }
49+
50+
else -> indexOfBestNotNullBy { this > it }
2951
}
3052

3153
/** T: Comparable<T> -> T(?) */

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/min.kt

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,27 +5,50 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.preserveRetu
55
import org.jetbrains.kotlinx.dataframe.impl.canBeNaN
66
import org.jetbrains.kotlinx.dataframe.impl.indexOfBestNotNaBy
77
import org.jetbrains.kotlinx.dataframe.impl.indexOfBestNotNullBy
8+
import org.jetbrains.kotlinx.dataframe.impl.isIntraComparable
89
import org.jetbrains.kotlinx.dataframe.impl.renderType
910
import kotlin.reflect.KType
11+
import kotlin.reflect.typeOf
12+
import kotlin.text.compareTo
1013

1114
@Suppress("UNCHECKED_CAST")
1215
@PublishedApi
1316
internal fun <T : Comparable<T>> Sequence<T>.minOrNull(type: KType, skipNaN: Boolean): T? {
1417
if (type.isMarkedNullable) {
1518
error("Encountered nullable type ${renderType(type)} in min function. This should not occur.")
1619
}
17-
return if (skipNaN && type.canBeNaN) {
18-
filterNot { it.isNaN }
19-
} else {
20-
this
21-
}.minOrNull()
20+
if (!type.isIntraComparable()) {
21+
error(
22+
"Encountered non-comparable type ${
23+
renderType(type)
24+
} in min function. Only self-comparable types are supported.",
25+
)
26+
}
27+
28+
return when {
29+
// filter out NaNs if requested
30+
skipNaN && type.canBeNaN -> this.filterNot { it.isNaN }.minOrNull()
31+
32+
// make sure that NaN is returned if it's in the sequence
33+
type == typeOf<Float>() -> (this as Sequence<Float>).minOrNull() as T?
34+
35+
// make sure that NaN is returned if it's in the sequence
36+
type == typeOf<Double>() -> (this as Sequence<Double>).minOrNull() as T?
37+
38+
else -> this.minOrNull()
39+
}
2240
}
2341

42+
@Suppress("UNCHECKED_CAST")
2443
internal fun <C : Comparable<C>> Sequence<C?>.indexOfMin(type: KType, skipNaN: Boolean): Int =
25-
if (skipNaN && type.canBeNaN) {
26-
indexOfBestNotNaBy { this < it }
27-
} else {
28-
indexOfBestNotNullBy { this < it }
44+
when {
45+
// filter out NaNs if requested
46+
skipNaN && type.canBeNaN -> indexOfBestNotNaBy { this < it }
47+
48+
// make sure the index of the first NaN is returned if it's in the sequence
49+
type.canBeNaN -> indexOfBestNotNullBy { this.isNaN || (!it.isNaN && this < it) }
50+
51+
else -> indexOfBestNotNullBy { this < it }
2952
}
3053

3154
/** T: Comparable<T> -> T(?) */

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/util/deprecationMessages.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ internal const val TO_LEFT_REPLACE = "this.toStart()"
7474
internal const val TO_RIGHT = "This `toRight` overload will be removed in favor of `toEnd`. $MESSAGE_0_16"
7575
internal const val TO_RIGHT_REPLACE = "this.toEnd()"
7676

77+
internal const val ROW_MIN = "`rowMin` is deprecated in favor of `rowMinOf`. $MESSAGE_0_16"
78+
internal const val ROW_MIN_OR_NULL = "`rowMinOrNull` is deprecated in favor of `rowMinOfOrNull`. $MESSAGE_0_16"
79+
80+
internal const val ROW_MAX = "`rowMax` is deprecated in favor of `rowMaxOf`. $MESSAGE_0_16"
81+
internal const val ROW_MAX_OR_NULL = "`rowMaxOrNull` is deprecated in favor of `rowMaxOfOrNull`. $MESSAGE_0_16"
82+
7783
// endregion
7884

7985
// region WARNING in 0.16, ERROR in 0.17

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/describe.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ class DescribeTests {
6262
nulls shouldBe 0
6363
top shouldBe 1
6464
freq shouldBe 1
65-
this.mean.shouldBeNaN()
65+
mean.shouldBeNaN()
6666
std.shouldBeNaN()
67-
min shouldBe 1.0 // TODO should be NaN too?
67+
min.isNaN shouldBe true
6868
p25 shouldBe 1.75
6969
median shouldBe 3.0
7070
p75.isNaN shouldBe true

0 commit comments

Comments
 (0)