Skip to content

Commit c01e4ae

Browse files
committed
finished types and aggregateBy for min
1 parent 3629b35 commit c01e4ae

File tree

8 files changed

+244
-131
lines changed

8 files changed

+244
-131
lines changed

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/groupBy.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,5 @@ public class ReducedGroupBy<T, G>(
109109
override fun toString(): String = "ReducedGroupBy(groupBy=$groupBy, reducer=$reducer)"
110110
}
111111

112+
@PublishedApi
112113
internal fun <T, G> GroupBy<T, G>.reduce(reducer: Selector<DataFrame<G>, DataRow<G>?>) = ReducedGroupBy(this, reducer)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/min.kt

Lines changed: 113 additions & 86 deletions
Large diffs are not rendered by default.

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/pivot.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@ public class ReducedPivot<T>(
243243
override fun toString(): String = "ReducedPivot(pivot=$pivot, reducer=$reducer)"
244244
}
245245

246+
@PublishedApi
246247
internal fun <T> Pivot<T>.reduce(reducer: Selector<DataFrame<T>, DataRow<T>?>) = ReducedPivot(this, reducer)
247248

248249
@PublishedApi

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,17 +105,19 @@ internal object Aggregators {
105105
) = AggregatorOptionSwitch2.Factory(getAggregator)
106106

107107
// T: Comparable<T> -> T?
108-
fun <T : Comparable<T>> min() = min.cast2<T, T?>()
108+
// T : Comparable<T & Any>? -> T?
109+
fun <T : Comparable<T & Any>?> min(): Aggregator<T & Any, T?> = min.cast2()
109110

110-
private val min by twoStepSelecting<Comparable<Any?>, Comparable<Any?>?>(
111+
private val min by twoStepSelecting<Comparable<Any>, Comparable<Any>?>(
111112
reducer = { type -> minOrNull(type) },
112113
indexOfResult = { indexOfMin() },
113114
)
114115

115116
// T: Comparable<T> -> T?
116-
fun <T : Comparable<T>> max() = max.cast2<T, T?>()
117+
// T : Comparable<T & Any>? -> T?
118+
fun <T : Comparable<T & Any>?> max(): Aggregator<T & Any, T?> = max.cast2()
117119

118-
private val max by twoStepSelecting<Comparable<Any?>, Comparable<Any?>?>(
120+
private val max by twoStepSelecting<Comparable<Any>, Comparable<Any>?>(
119121
reducer = { type -> maxOrNull(type) },
120122
indexOfResult = { indexOfMax() },
121123
)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/aggregateBy.kt

Lines changed: 60 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@ import org.jetbrains.kotlinx.dataframe.DataColumn
44
import org.jetbrains.kotlinx.dataframe.DataFrame
55
import org.jetbrains.kotlinx.dataframe.DataFrameExpression
66
import org.jetbrains.kotlinx.dataframe.DataRow
7+
import org.jetbrains.kotlinx.dataframe.RowExpression
78
import org.jetbrains.kotlinx.dataframe.annotations.CandidateForRemoval
89
import org.jetbrains.kotlinx.dataframe.api.GroupBy
910
import org.jetbrains.kotlinx.dataframe.api.Grouped
11+
import org.jetbrains.kotlinx.dataframe.api.asSequence
1012
import org.jetbrains.kotlinx.dataframe.api.cast
13+
import org.jetbrains.kotlinx.dataframe.api.getOrNull
14+
import org.jetbrains.kotlinx.dataframe.columns.ColumnReference
1115
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregateInternal
1216
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator
1317
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.indexOfAggregationResult
@@ -27,41 +31,88 @@ internal fun <T> Grouped<T>.aggregateByOrNull(body: DataFrameExpression<T, DataR
2731
}
2832

2933
/**
30-
* Aggregates the values of the column using the provided [Aggregator] `by` the provided [selector].
34+
* Selects the best matching value in the [sequence][values]
35+
* using the provided [Aggregator] `by` the provided [selector].
36+
*
37+
* @param V is used to infer whether there are nulls in the values fed to the aggregator!
3138
*/
3239
@Suppress("UNCHECKED_CAST")
3340
@PublishedApi
3441
internal inline fun <T, reified V : R, R : Any?> Aggregator<V & Any, R>.aggregateByOrNull(
3542
values: Sequence<T>,
36-
noinline selector: (T) -> V,
43+
crossinline selector: (T) -> V,
3744
): T? =
3845
values.elementAtOrNull(
3946
indexOfAggregationResult(
40-
values = values.map(selector),
47+
values = values.map { selector(it) },
4148
valueType = typeOf<V>(),
4249
),
4350
)
4451

4552
/**
46-
* Aggregates the values of the column using the provided [Aggregator] `by` the provided [selector].
53+
* Selects the best matching value in the [iterable][values]
54+
* using the provided [Aggregator] `by` the provided [selector].
4755
*
4856
* Faster implementation than for sequences.
57+
*
58+
* @param V is used to infer whether there are nulls in the values fed to the aggregator!
4959
*/
5060
@Suppress("UNCHECKED_CAST")
5161
@PublishedApi
52-
internal inline fun <T, reified V : R, R> Aggregator<V & Any, R>.aggregateByOrNull(
62+
internal inline fun <T, reified V : R, R : Any?> Aggregator<V & Any, R>.aggregateByOrNull(
5363
values: Iterable<T>,
54-
noinline selector: (T) -> V,
64+
crossinline selector: (T) -> V,
5565
): T? =
5666
values.elementAtOrNull(
5767
indexOfAggregationResult(
58-
values = values.asSequence().map(selector),
68+
values = values.asSequence().map { selector(it) },
5969
valueType = typeOf<V>(),
6070
),
6171
)
6272

73+
/**
74+
* Selects the best matching value in the [column] using the provided [Aggregator] `by` the provided [selector].
75+
*
76+
* @param V is used to infer whether there are nulls in the values fed to the aggregator!
77+
*/
6378
@PublishedApi
64-
internal inline fun <T, reified V : R, R> Aggregator<V & Any, R>.aggregateByOrNull(
79+
internal inline fun <T, reified V : R, R : Any?> Aggregator<V & Any, R>.aggregateByOrNull(
6580
column: DataColumn<T>,
66-
noinline selector: (T) -> V,
81+
crossinline selector: (T) -> V,
6782
): T? = aggregateByOrNull(column.values(), selector)
83+
84+
/**
85+
* Selects the best matching value in the [dataframe][data]
86+
* using the provided [Aggregator] `by` the provided [rowExpression].
87+
*
88+
* @param V is used to infer whether there are nulls in the values fed to the aggregator!
89+
*/
90+
@PublishedApi
91+
internal inline fun <T, reified V : R, R : Any?> Aggregator<V & Any, R>.aggregateByOrNull(
92+
data: DataFrame<T>,
93+
crossinline rowExpression: RowExpression<T, V>,
94+
): DataRow<T>? =
95+
data.getOrNull(
96+
indexOfAggregationResult(
97+
values = data.asSequence().map { rowExpression(it, it) },
98+
valueType = typeOf<V>(),
99+
),
100+
)
101+
102+
/**
103+
* Selects the best matching value in the [dataframe][data]
104+
* using the provided [Aggregator] `by` the provided [column].
105+
*
106+
* @param V is used to infer whether there are nulls in the values fed to the aggregator!
107+
*/
108+
@PublishedApi
109+
internal inline fun <T, reified V : R, R> Aggregator<V & Any, R>.aggregateByOrNull(
110+
data: DataFrame<T>,
111+
column: ColumnReference<V>,
112+
): DataRow<T>? =
113+
data.getOrNull(
114+
indexOfAggregationResult(
115+
values = data.asSequence().map { it[column] },
116+
valueType = typeOf<V>(),
117+
),
118+
)

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/ofRowExpression.kt

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.modes
22

33
import org.jetbrains.kotlinx.dataframe.DataColumn
44
import org.jetbrains.kotlinx.dataframe.DataFrame
5-
import org.jetbrains.kotlinx.dataframe.DataRow
65
import org.jetbrains.kotlinx.dataframe.RowExpression
76
import org.jetbrains.kotlinx.dataframe.aggregation.AggregateBody
87
import org.jetbrains.kotlinx.dataframe.api.Grouped
@@ -18,24 +17,42 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.internal
1817
import org.jetbrains.kotlinx.dataframe.impl.emptyPath
1918
import kotlin.reflect.typeOf
2019

20+
/**
21+
* Aggregates [values] by first applying [transform] to each element of the sequence and then
22+
* applying the [Aggregator] ([this]) to the resulting sequence.
23+
*
24+
* @param V is used to infer whether there are nulls in the values fed to the aggregator!
25+
*/
2126
@PublishedApi
2227
internal inline fun <C, reified V : Any?, R : Any?> Aggregator<V & Any, R>.aggregateOf(
2328
values: Sequence<C>,
2429
crossinline transform: (C) -> V,
25-
): R = aggregate(values = values.mapNotNull { transform(it) }, valueType = typeOf<V>())
30+
): R = aggregate(values = values.map { transform(it) }, valueType = typeOf<V>())
2631

32+
/**
33+
* Aggregates [column] by first applying [transform] to each element of the column and then
34+
* applying the [Aggregator] ([this]) to the resulting sequence.
35+
*
36+
* @param V is used to infer whether there are nulls in the values fed to the aggregator!
37+
*/
2738
@PublishedApi
2839
internal inline fun <C, reified V : Any?, R : Any?> Aggregator<V & Any, R>.aggregateOf(
2940
column: DataColumn<C>,
3041
crossinline transform: (C) -> V,
3142
): R = aggregateOf(column.asSequence(), transform)
3243

44+
/**
45+
* Aggregates [frame] by first applying [expression] to each row of the frame and then
46+
* applying the [Aggregator] ([this]) to the resulting sequence.
47+
*
48+
* @param V is used to infer whether there are nulls in the values fed to the aggregator!
49+
*/
3350
@Suppress("UNCHECKED_CAST")
3451
@PublishedApi
3552
internal inline fun <T, reified V : Any?, R : Any?> Aggregator<*, R>.aggregateOf(
3653
frame: DataFrame<T>,
3754
crossinline expression: RowExpression<T, V>,
38-
): R = (this as Aggregator<V & Any, R>).aggregateOf<DataRow<T>, V, R>(frame.rows().asSequence()) { expression(it, it) }
55+
): R = (this as Aggregator<V & Any, R>).aggregateOf(frame.rows().asSequence()) { expression(it, it) }
3956

4057
@PublishedApi
4158
internal fun <T, C, R : Any?> Aggregator<*, R>.aggregateOfDelegated(
@@ -47,26 +64,37 @@ internal fun <T, C, R : Any?> Aggregator<*, R>.aggregateOfDelegated(
4764
body(this, this)
4865
}
4966

67+
/**
68+
* Aggregates [data] by first applying [expression] to each row of the frame and then
69+
*
70+
* @param C is used to infer whether there are nulls in the values fed to the aggregator!
71+
*/
5072
@Suppress("UNCHECKED_CAST")
5173
@PublishedApi
52-
internal inline fun <T, reified C : Any, reified R : Any?> Aggregator<*, R>.aggregateOf(
74+
internal inline fun <T, reified C : Any?, reified R : Any?> Aggregator<*, R>.aggregateOf(
5375
data: Grouped<T>,
5476
resultName: String? = null,
55-
crossinline expression: RowExpression<T, C?>,
77+
crossinline expression: RowExpression<T, C>,
5678
): DataFrame<T> = data.aggregateOf(resultName, expression, this as Aggregator<C, R>)
5779

80+
/**
81+
* @param C is used to infer whether there are nulls in the values fed to the aggregator!
82+
*/
5883
@Suppress("UNCHECKED_CAST")
5984
@PublishedApi
60-
internal inline fun <T, reified C : Any, reified R : Any?> Aggregator<*, R>.aggregateOf(
85+
internal inline fun <T, reified C : Any?, reified R : Any?> Aggregator<*, R>.aggregateOf(
6186
data: PivotGroupBy<T>,
62-
crossinline expression: RowExpression<T, C?>,
63-
): DataFrame<T> = data.aggregateOf<T, C, R>(expression, this as Aggregator<C, R>)
87+
crossinline expression: RowExpression<T, C>,
88+
): DataFrame<T> = data.aggregateOf(expression, this as Aggregator<C, R>)
6489

90+
/**
91+
* @param C is used to infer whether there are nulls in the values fed to the aggregator!
92+
*/
6593
@PublishedApi
66-
internal inline fun <T, reified C : Any, reified R : Any?> Grouped<T>.aggregateOf(
94+
internal inline fun <T, reified C : Any?, reified R : Any?> Grouped<T>.aggregateOf(
6795
resultName: String?,
68-
crossinline expression: RowExpression<T, C?>,
69-
aggregator: Aggregator<C, R>,
96+
crossinline expression: RowExpression<T, C>,
97+
aggregator: Aggregator<C & Any, R>,
7098
): DataFrame<T> {
7199
val path = pathOf(resultName ?: aggregator.name)
72100
val expressionResultType = typeOf<C>()
@@ -86,10 +114,13 @@ internal inline fun <T, reified C : Any, reified R : Any?> Grouped<T>.aggregateO
86114
}
87115
}
88116

117+
/**
118+
* @param C is used to infer whether there are nulls in the values fed to the aggregator!
119+
*/
89120
@PublishedApi
90-
internal inline fun <T, reified C : Any, R : Any?> PivotGroupBy<T>.aggregateOf(
91-
crossinline expression: RowExpression<T, C?>,
92-
aggregator: Aggregator<C, R>,
121+
internal inline fun <T, reified C : Any?, R : Any?> PivotGroupBy<T>.aggregateOf(
122+
crossinline expression: RowExpression<T, C>,
123+
aggregator: Aggregator<C & Any, R>,
93124
): DataFrame<T> =
94125
aggregate {
95126
internal().yield(emptyPath(), aggregator.aggregateOf(this, expression))

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregateCal
1515
* @param columns selector of which columns inside the [row] to aggregate
1616
*/
1717
@PublishedApi
18-
internal fun <T, V : Any, R : Any?> Aggregator<V, R>.aggregateOfRow(
18+
internal fun <T, V : Any?, R : Any?> Aggregator<V & Any, R>.aggregateOfRow(
1919
row: DataRow<T>,
20-
columns: ColumnsSelector<T, V?>,
20+
columns: ColumnsSelector<T, V>,
2121
): R {
2222
val filteredColumns = row.df().getColumns(columns).asSequence()
2323
return aggregateCalculatingValueType(

core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,30 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.internal
1515
import org.jetbrains.kotlinx.dataframe.impl.emptyPath
1616

1717
@PublishedApi
18-
internal fun <T, C : Any, R : Any?> Aggregator<*, R>.aggregateAll(
18+
internal fun <T, C : Any?, R : Any?> Aggregator<*, R>.aggregateAll(
1919
data: DataFrame<T>,
20-
columns: ColumnsSelector<T, C?>,
21-
): R = data.aggregateAll(cast2(), columns)
20+
columns: ColumnsSelector<T, C>,
21+
): R = data.aggregateAll(cast2<C & Any, R>(), columns)
2222

23-
internal fun <T, C : Any, R : Any?> Aggregator<*, R>.aggregateAll(
23+
internal fun <T, C : Any?, R : Any?> Aggregator<*, R>.aggregateAll(
2424
data: Grouped<T>,
2525
name: String?,
26-
columns: ColumnsSelector<T, C?>,
26+
columns: ColumnsSelector<T, C>,
2727
): DataFrame<T> = data.aggregateAll(cast(), columns, name)
2828

29-
internal fun <T, C : Any, R : Any?> Aggregator<*, R>.aggregateAll(
29+
internal fun <T, C : Any?, R : Any?> Aggregator<*, R>.aggregateAll(
3030
data: PivotGroupBy<T>,
31-
columns: ColumnsSelector<T, C?>,
31+
columns: ColumnsSelector<T, C>,
3232
): DataFrame<T> = data.aggregateAll(cast(), columns)
3333

34-
internal fun <T, C : Any, R : Any?> DataFrame<T>.aggregateAll(
35-
aggregator: Aggregator<C, R>,
36-
columns: ColumnsSelector<T, C?>,
34+
internal fun <T, C : Any?, R : Any?> DataFrame<T>.aggregateAll(
35+
aggregator: Aggregator<C & Any, R>,
36+
columns: ColumnsSelector<T, C>,
3737
): R = aggregator.aggregateMultipleColumns(get(columns).asSequence())
3838

39-
internal fun <T, C : Any, R : Any?> Grouped<T>.aggregateAll(
40-
aggregator: Aggregator<C, R>,
41-
columns: ColumnsSelector<T, C?>,
39+
internal fun <T, C : Any?, R : Any?> Grouped<T>.aggregateAll(
40+
aggregator: Aggregator<C & Any, R>,
41+
columns: ColumnsSelector<T, C>,
4242
name: String?,
4343
): DataFrame<T> =
4444
aggregateInternal {
@@ -50,9 +50,9 @@ internal fun <T, C : Any, R : Any?> Grouped<T>.aggregateAll(
5050
}
5151
}
5252

53-
internal fun <T, C : Any, R : Any?> PivotGroupBy<T>.aggregateAll(
54-
aggregator: Aggregator<C, R>,
55-
columns: ColumnsSelector<T, C?>,
53+
internal fun <T, C : Any?, R : Any?> PivotGroupBy<T>.aggregateAll(
54+
aggregator: Aggregator<C & Any, R>,
55+
columns: ColumnsSelector<T, C>,
5656
): DataFrame<T> =
5757
aggregate {
5858
val cols = get(columns)

0 commit comments

Comments
 (0)