diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index 76b470bbf4..c0d9dae69c 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -125,10 +125,13 @@ public fun AnyRow.rowSumOf(type: KType, skipNaN: Boolean = skipNaNDefault): Numb // endregion // region DataFrame - +@Refine +@Interpretable("Sum0") public fun DataFrame.sum(skipNaN: Boolean = skipNaNDefault): DataRow = sumFor(skipNaN, primitiveOrMixedNumberColumns()) +@Refine +@Interpretable("Sum1") public fun DataFrame.sumFor( skipNaN: Boolean = skipNaNDefault, columns: ColumnsForAggregateSelector, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt index c5ca55381a..a7cb2a8b14 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt @@ -37,12 +37,11 @@ import kotlin.reflect.full.withNullability * @param Return The type of the resulting value. Can optionally be nullable. * @see [invoke] */ -@PublishedApi -internal class Aggregator( - val aggregationHandler: AggregatorAggregationHandler, - val inputHandler: AggregatorInputHandler, - val multipleColumnsHandler: AggregatorMultipleColumnsHandler, - val name: String, +public class Aggregator( + public val aggregationHandler: AggregatorAggregationHandler, + public val inputHandler: AggregatorInputHandler, + public val multipleColumnsHandler: AggregatorMultipleColumnsHandler, + public val name: String, ) : AggregatorInputHandler by inputHandler, AggregatorMultipleColumnsHandler by multipleColumnsHandler, AggregatorAggregationHandler by aggregationHandler { @@ -96,7 +95,7 @@ internal class Aggregator( internal fun Aggregator.aggregate( values: Sequence, valueType: ValueType, -) = aggregateSequence(values, valueType) +): Return = aggregateSequence(values, valueType) /** * Performs aggregation on the given [values], taking [valueType] into account. @@ -106,7 +105,7 @@ internal fun Aggregator.aggregate( internal fun Aggregator.aggregate( values: Sequence, valueType: KType, -) = aggregate(values, valueType.toValueType(needsFullConversion = false)) +): Return = aggregate(values, valueType.toValueType(needsFullConversion = false)) /** * If the specific [ValueType] of the input is not known, but you still want to call [aggregate], diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt index fc4af0ffe5..7b1b0357eb 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt @@ -11,8 +11,7 @@ import kotlin.reflect.KType * It also provides information on which return type will be given, as [KType], given a [value type][ValueType]. * It can also provide the index of the result in the input values if it is a selecting aggregator. */ -@PublishedApi -internal interface AggregatorAggregationHandler : AggregatorHandler { +public interface AggregatorAggregationHandler : AggregatorHandler { /** * Base function of [Aggregator]. @@ -23,13 +22,13 @@ internal interface AggregatorAggregationHandler, valueType: ValueType): Return + public fun aggregateSequence(values: Sequence, valueType: ValueType): Return /** * Aggregates the data in the given column and computes a single resulting value. * Calls [aggregateSequence]. */ - fun aggregateSingleColumn(column: DataColumn): Return = + public fun aggregateSingleColumn(column: DataColumn): Return = aggregateSequence( values = column.asSequence(), valueType = column.type().toValueType(), @@ -43,7 +42,7 @@ internal interface AggregatorAggregationHandler, valueType: ValueType): Int + public fun indexOfAggregationResultSingleSequence(values: Sequence, valueType: ValueType): Int } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorHandler.kt index 6d7a87a02d..98558a5556 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorHandler.kt @@ -7,16 +7,16 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators * the [init] function of each [AggregatorAggregationHandlers][AggregatorAggregationHandler] is called, * which allows the handler to refer to [Aggregator] instance via [aggregator]. */ -internal interface AggregatorHandler { +public interface AggregatorHandler { /** * Reference to the aggregator instance. * * Can only be used once [init] has run. */ - var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>? + public var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>? - fun init(aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>) { + public fun init(aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>) { this.aggregator = aggregator } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler.kt index e3a2bb64ac..73c0dcde4b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler.kt @@ -8,13 +8,13 @@ import kotlin.reflect.KType * It can also calculate a specific [value type][ValueType] from the input values or input types * if the (specific) type is not known. */ -internal interface AggregatorInputHandler : AggregatorHandler { +public interface AggregatorInputHandler : AggregatorHandler { /** * If the specific [ValueType] of the input is not known, but you still want to call [aggregate], * this function can be called to calculate it by combining the set of known [valueTypes]. */ - fun calculateValueType(valueTypes: Set): ValueType + public fun calculateValueType(valueTypes: Set): ValueType /** * WARNING: HEAVY! @@ -23,7 +23,7 @@ internal interface AggregatorInputHandler : A * this function can be called to calculate it by getting the types of [values] at runtime. * This is heavy because it uses reflection on each value. */ - fun calculateValueType(values: Sequence): ValueType + public fun calculateValueType(values: Sequence): ValueType /** * Preprocesses the input values before aggregation. @@ -32,7 +32,7 @@ internal interface AggregatorInputHandler : A * * @return A pair of the preprocessed values and the (potentially new) type of the values. */ - fun preprocessAggregation( + public fun preprocessAggregation( values: Sequence, valueType: ValueType, ): Pair, KType> diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler.kt index 0530f05ae7..eb5dc1ff59 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler.kt @@ -10,14 +10,14 @@ import kotlin.reflect.KType * [AggregatorAggregationHandler]. * It can also calculate the return type of the aggregation given all input column types. */ -internal interface AggregatorMultipleColumnsHandler : +public interface AggregatorMultipleColumnsHandler : AggregatorHandler { /** * Aggregates the data in the multiple given columns and computes a single resulting value. * Calls [Aggregator.aggregateSequence] or [Aggregator.aggregateSingleColumn]. */ - fun aggregateMultipleColumns(columns: Sequence>): Return + public fun aggregateMultipleColumns(columns: Sequence>): Return /** * Function that can give the return type of [aggregateMultipleColumns], given types of the columns. @@ -26,5 +26,5 @@ internal interface AggregatorMultipleColumnsHandler, colsEmpty: Boolean): KType + public fun calculateReturnTypeMultipleColumns(colTypes: Set, colsEmpty: Boolean): KType } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt index 987771547a..75bd9865d6 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt @@ -6,21 +6,20 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators * Aggregators are cached by their parameter value. * @see AggregatorOptionSwitch2 */ -@PublishedApi -internal class AggregatorOptionSwitch1( - val name: String, - val getAggregator: (param1: Param1) -> AggregatorProvider, +public class AggregatorOptionSwitch1( + public val name: String, + public val getAggregator: (param1: Param1) -> AggregatorProvider, ) { private val cache: MutableMap> = mutableMapOf() - operator fun invoke(param1: Param1): Aggregator = + public operator fun invoke(param1: Param1): Aggregator = cache.getOrPut(param1) { getAggregator(param1).create(name) } @Suppress("FunctionName") - companion object { + public companion object { /** * Creates [AggregatorOptionSwitch1]. @@ -31,9 +30,10 @@ internal class AggregatorOptionSwitch1 Factory( + public fun Factory( getAggregator: (param1: Param1) -> AggregatorProvider, - ) = Provider { name -> AggregatorOptionSwitch1(name, getAggregator) } + ): Provider> = + Provider { name -> AggregatorOptionSwitch1(name, getAggregator) } } } @@ -43,21 +43,20 @@ internal class AggregatorOptionSwitch1( - val name: String, - val getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider, +public class AggregatorOptionSwitch2( + public val name: String, + public val getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider, ) { private val cache: MutableMap, Aggregator> = mutableMapOf() - operator fun invoke(param1: Param1, param2: Param2): Aggregator = + public operator fun invoke(param1: Param1, param2: Param2): Aggregator = cache.getOrPut(param1 to param2) { getAggregator(param1, param2).create(name) } @Suppress("FunctionName") - companion object { + public companion object { /** * Creates [AggregatorOptionSwitch2]. @@ -68,7 +67,7 @@ internal class AggregatorOptionSwitch2 Factory( + internal fun Factory( getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider, ) = Provider { name -> AggregatorOptionSwitch2(name, getAggregator) } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt index fd1ef8db35..0405f88f1e 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt @@ -10,9 +10,9 @@ import kotlin.reflect.KProperty * val myNamedValue by MyFactory * ``` */ -internal fun interface Provider { +public fun interface Provider { - fun create(name: String): T + public fun create(name: String): T } internal operator fun Provider.getValue(obj: Any?, property: KProperty<*>): T = create(property.name) @@ -25,4 +25,5 @@ internal operator fun Provider.getValue(obj: Any?, property: KProperty<*> * val myAggregator by MyAggregator.Factory * ``` */ -internal fun interface AggregatorProvider : Provider> +public fun interface AggregatorProvider : + Provider> diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index d023342cac..d754aa3145 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -27,8 +27,7 @@ import org.jetbrains.kotlinx.dataframe.math.stdTypeConversion import org.jetbrains.kotlinx.dataframe.math.sum import org.jetbrains.kotlinx.dataframe.math.sumTypeConversion -@PublishedApi -internal object Aggregators { +public object Aggregators { // TODO these might need some small refactoring @@ -112,7 +111,7 @@ internal object Aggregators { // T: Comparable -> T? // T : Comparable? -> T? - fun ?> min(skipNaN: Boolean): Aggregator = min.invoke(skipNaN).cast2() + public fun ?> min(skipNaN: Boolean): Aggregator = min.invoke(skipNaN).cast2() private val min by withOneOption { skipNaN: Boolean -> twoStepSelectingForAny, Comparable?>( @@ -124,7 +123,7 @@ internal object Aggregators { // T: Comparable -> T? // T : Comparable? -> T? - fun ?> max(skipNaN: Boolean): Aggregator = max.invoke(skipNaN).cast2() + public fun ?> max(skipNaN: Boolean): Aggregator = max.invoke(skipNaN).cast2() private val max by withOneOption { skipNaN: Boolean -> twoStepSelectingForAny, Comparable?>( @@ -135,7 +134,10 @@ internal object Aggregators { } // T: Number? -> Double - val std by withTwoOptions { skipNaN: Boolean, ddof: Int -> + public val std: AggregatorOptionSwitch2 by withTwoOptions { + skipNaN: Boolean, + ddof: Int, + -> flattenReducingForNumbers(stdTypeConversion) { type -> std(type, skipNaN, ddof) } @@ -143,7 +145,7 @@ internal object Aggregators { // step one: T: Number? -> Double // step two: Double -> Double - val mean by withOneOption { skipNaN: Boolean -> + public val mean: AggregatorOptionSwitch1 by withOneOption { skipNaN: Boolean -> twoStepReducingForNumbers(meanTypeConversion) { type -> mean(type, skipNaN) } @@ -151,7 +153,7 @@ internal object Aggregators { // T : primitive Number? -> Double? // T : Comparable? -> T? - fun percentileCommon( + public fun percentileCommon( percentile: Double, skipNaN: Boolean, ): Aggregator @@ -159,12 +161,14 @@ internal object Aggregators { this.percentile.invoke(percentile, skipNaN).cast2() // T : Comparable? -> T? - fun percentileComparables(percentile: Double): Aggregator + public fun percentileComparables( + percentile: Double, + ): Aggregator where T : Comparable? = percentileCommon(percentile, skipNaNDefault).cast2() // T : primitive Number? -> Double? - fun percentileNumbers( + public fun percentileNumbers( percentile: Double, skipNaN: Boolean, ): Aggregator @@ -182,17 +186,17 @@ internal object Aggregators { // T : primitive Number? -> Double? // T : Comparable? -> T? - fun medianCommon(skipNaN: Boolean): Aggregator + public fun medianCommon(skipNaN: Boolean): Aggregator where T : Comparable? = median.invoke(skipNaN).cast2() // T : Comparable? -> T? - fun medianComparables(): Aggregator + public fun medianComparables(): Aggregator where T : Comparable? = medianCommon(skipNaNDefault).cast2() // T : primitive Number? -> Double? - fun medianNumbers( + public fun medianNumbers( skipNaN: Boolean, ): Aggregator where T : Comparable?, T : Number? = @@ -211,7 +215,7 @@ internal object Aggregators { // Byte -> Int // Short -> Int // Nothing -> Double - val sum by withOneOption { skipNaN: Boolean -> + public val sum: AggregatorOptionSwitch1 by withOneOption { skipNaN: Boolean -> twoStepReducingForNumbers(sumTypeConversion) { type -> sum(type, skipNaN) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType.kt index be886c9578..91157483ff 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType.kt @@ -10,6 +10,6 @@ import kotlin.reflect.KType * for the values to become the correct value type. If `false`, the values are already the right type, * or a simple cast will suffice. */ -internal data class ValueType(val kType: KType, val needsFullConversion: Boolean = false) +public data class ValueType(val kType: KType, val needsFullConversion: Boolean = false) internal fun KType.toValueType(needsFullConversion: Boolean = false): ValueType = ValueType(this, needsFullConversion) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt index e53e6470a4..c24a3e34a8 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt @@ -31,7 +31,7 @@ import kotlin.reflect.KType * If not supplied, the handler of the first step is reused. * @see [FlatteningMultipleColumnsHandler] */ -internal class TwoStepMultipleColumnsHandler( +internal class TwoStepMultipleColumnsHandler( stepTwoAggregationHandler: AggregatorAggregationHandler? = null, stepTwoInputHandler: AggregatorInputHandler? = null, ) : AggregatorMultipleColumnsHandler { diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt index 546df300e1..726ed62a59 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt @@ -5,24 +5,98 @@ import org.junit.Test @Suppress("ktlint:standard:argument-list-wrapping") class StatisticsTests { - private val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( - "Alice", 15, "London", 99.5, "1.85", 50, - "Bob", 20, "Paris", 140.0, "1.35", 45, - "Charlie", 100, "Dubai", 75.0, "1.95", 0, - "Rose", 1, "Moscow", 45.33, "0.79", 64, - "Dylan", 35, "London", 23.4, "1.83", 30, - "Eve", 40, "Paris", 56.72, "1.85", 25, - "Frank", 55, "Dubai", 78.9, "1.35", 10, - "Grace", 29, "Moscow", 67.8, "1.65", 36, - "Hank", 60, "Paris", 80.22, "1.75", 5, - "Isla", 22, "London", 75.1, "1.85", 43, + private val personsDf = dataFrameOf( + "name", + "age", + "city", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", + )( + "Alice", 15, "London", 99.5, "1.85", 50, 0.toShort(), 0.toByte(), 0L, + "Bob", 20, "Paris", 140.0, "1.35", 45, 2.toShort(), 0.toByte(), 12000L, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, 70.toShort(), 0.toByte(), 0L, + "Rose", 1, "Moscow", 45.33, "0.79", 64, 0.toShort(), 2.toByte(), 0L, + "Dylan", 35, "London", 23.4, "1.83", 30, 15.toShort(), 1.toByte(), 90000L, + "Eve", 40, "Paris", 56.72, "1.85", 25, 18.toShort(), 3.toByte(), 125000L, + "Frank", 55, "Dubai", 78.9, "1.35", 10, 35.toShort(), 2.toByte(), 145000L, + "Grace", 29, "Moscow", 67.8, "1.65", 36, 5.toShort(), 1.toByte(), 70000L, + "Hank", 60, "Paris", 80.22, "1.75", 5, 40.toShort(), 4.toByte(), 200000L, + "Isla", 22, "London", 75.1, "1.85", 43, 1.toShort(), 0.toByte(), 30000L, ) + @Test + fun `sum on DataFrame`() { + // scenario #0: all numerical columns + val res0 = personsDf.sum() + res0.columnNames() shouldBe listOf( + "age", + "weight", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", + ) + + val sum01 = res0["age"] as Int + sum01 shouldBe 377 + val sum02 = res0["weight"] as Double + sum02 shouldBe 741.9699999999999 + val sum03 = res0["yearsToRetirement"] as Int + sum03 shouldBe 308 + val sum04 = res0["workExperienceYears"] as Int + sum04 shouldBe 186 + val sum05 = res0["dependentsCount"] as Int + sum05 shouldBe 13.0 + val sum06 = res0["annualIncome"] as Long + sum06 shouldBe 672000 + + // scenario #1: particular column + val res1 = personsDf.sumFor("age") + res1.columnNames() shouldBe listOf("age") + + val sum11 = res1["age"] as Int + sum11 shouldBe 377 + + // scenario #1.1: particular column with converted type + val res11 = personsDf.sumFor("dependentsCount") + res11.columnNames() shouldBe listOf("dependentsCount") + + val sum111 = res11["dependentsCount"] as Int + sum111 shouldBe 13 + + // scenario #2: sum of values per columns separately + val res3 = personsDf.sumFor("age", "weight", "workExperienceYears", "dependentsCount", "annualIncome") + res3.columnNames() shouldBe listOf("age", "weight", "workExperienceYears", "dependentsCount", "annualIncome") + + val sum31 = res3["age"] as Int + sum31 shouldBe 377 + val sum32 = res0["weight"] as Double + sum32 shouldBe 741.9699999999999 + val sum33 = res0["workExperienceYears"] as Int + sum33 shouldBe 186 + val sum34 = res0["dependentsCount"] as Int + sum34 shouldBe 13.0 + val sum35 = res0["annualIncome"] as Long + sum35 shouldBe 672000 + } + @Test fun `sum on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").sum() - res0.columnNames() shouldBe listOf("city", "age", "weight", "yearsToRetirement") + res0.columnNames() shouldBe listOf( + "city", + "age", + "weight", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", + ) val sum01 = res0["age"][0] as Int sum01 shouldBe 72 @@ -83,7 +157,15 @@ class StatisticsTests { fun `mean on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").mean() - res0.columnNames() shouldBe listOf("city", "age", "weight", "yearsToRetirement") + res0.columnNames() shouldBe listOf( + "city", + "age", + "weight", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", + ) val mean01 = res0["age"][0] as Double mean01 shouldBe 24.0 @@ -151,6 +233,9 @@ class StatisticsTests { "weight", "height", "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", ) val median01 = res0["age"][0] as Double @@ -212,7 +297,15 @@ class StatisticsTests { fun `std on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").std() - res0.columnNames() shouldBe listOf("city", "age", "weight", "yearsToRetirement") + res0.columnNames() shouldBe listOf( + "city", + "age", + "weight", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", + ) val std01 = res0["age"][0] as Double std01 shouldBe 10.14889156509222 @@ -280,6 +373,9 @@ class StatisticsTests { "weight", "height", "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", ) val min01 = res0["age"][0] as Int @@ -345,6 +441,9 @@ class StatisticsTests { "weight", "height", "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", ) // TODO: why is here weight presented? looks like inconsitency val min41 = res4["age"][0] as Int @@ -354,7 +453,17 @@ class StatisticsTests { // scenario #5: particular column via minBy and rowExpression val res5 = personsDf.groupBy("city").minBy { "age"() * 10 }.values() - res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height", "yearsToRetirement") + res4.columnNames() shouldBe listOf( + "city", + "name", + "age", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", + ) val min51 = res5["age"][0] as Int min51 shouldBe 15 @@ -364,7 +473,17 @@ class StatisticsTests { fun `max on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").max() - res0.columnNames() shouldBe listOf("city", "name", "age", "weight", "height", "yearsToRetirement") + res0.columnNames() shouldBe listOf( + "city", + "name", + "age", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", + ) val max01 = res0["age"][0] as Int max01 shouldBe 35 @@ -429,6 +548,9 @@ class StatisticsTests { "weight", "height", "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", ) // TODO: weight is here? val max41 = res4["age"][0] as Int @@ -438,7 +560,17 @@ class StatisticsTests { // scenario #5: particular column via maxBy and rowExpression val res5 = personsDf.groupBy("city").maxBy { "age"() * 10 }.values() - res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height", "yearsToRetirement") + res4.columnNames() shouldBe listOf( + "city", + "name", + "age", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome", + ) val max51 = res5["age"][0] as Int max51 shouldBe 35 diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt new file mode 100644 index 0000000000..079556098d --- /dev/null +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt @@ -0,0 +1,124 @@ +package org.jetbrains.kotlinx.dataframe.plugin.impl.api + +import org.jetbrains.kotlin.fir.types.ConeClassLikeErrorLookupTag +import org.jetbrains.kotlin.fir.types.isSubtypeOf +import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments +import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema +import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn +import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame +import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf +import kotlin.reflect.KType +import kotlin.reflect.full.createType +import org.jetbrains.kotlin.fir.types.ConeKotlinType +import org.jetbrains.kotlin.fir.types.ConeClassLikeType +import org.jetbrains.kotlin.fir.types.ConeNullability +import org.jetbrains.kotlin.fir.types.constructClassLikeType +import org.jetbrains.kotlin.fir.types.impl.ConeClassLikeTypeImpl +import org.jetbrains.kotlin.name.ClassId +import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator +import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol +import kotlin.reflect.KClass + +private object PrimitiveClassIds { + const val INT = "kotlin/Int" + const val LONG = "kotlin/Long" + const val DOUBLE = "kotlin/Double" + const val FLOAT = "kotlin/Float" + const val SHORT = "kotlin/Short" + const val BYTE = "kotlin/Byte" +} + +private fun KClass<*>.toClassId(): ClassId? = when (this) { + Int::class -> ClassId.fromString(PrimitiveClassIds.INT) + Long::class -> ClassId.fromString(PrimitiveClassIds.LONG) + Double::class -> ClassId.fromString(PrimitiveClassIds.DOUBLE) + Float::class -> ClassId.fromString(PrimitiveClassIds.FLOAT) + Short::class -> ClassId.fromString(PrimitiveClassIds.SHORT) + Byte::class -> ClassId.fromString(PrimitiveClassIds.BYTE) + else -> null +} + +private val primitiveTypeMap = mapOf( + PrimitiveClassIds.INT to Int::class, + PrimitiveClassIds.LONG to Long::class, + PrimitiveClassIds.DOUBLE to Double::class, + PrimitiveClassIds.FLOAT to Float::class, + PrimitiveClassIds.SHORT to Short::class, + PrimitiveClassIds.BYTE to Byte::class +) + +fun ConeKotlinType.toKType(): KType? { + return (this as? ConeClassLikeType)?.let { coneType -> + val nullable = coneType.nullability == ConeNullability.NULLABLE + primitiveTypeMap[coneType.lookupTag.classId.asString()] + ?.createType(nullable = nullable) + } +} + +fun KType.toConeKotlinType(): ConeKotlinType? { + val kClass = this.classifier as? KClass<*> ?: return null + val classId = kClass.toClassId() ?: return null + + return classId.constructClassLikeType( + typeArguments = emptyArray(), + isNullable = this.isMarkedNullable + ) +} + +private fun Arguments.generateStatisticResultColumns( + statisticAggregator: Aggregator, + inputColumns: List +): List { + return inputColumns.map { col -> createUpdatedColumn(col, statisticAggregator) } +} + +private fun Arguments.createUpdatedColumn( + column: SimpleDataColumn, + statisticAggregator: Aggregator +): SimpleCol { + val originalType = column.type.type + val inputKType = originalType.toKType() + val resultKType = inputKType?.let { statisticAggregator.calculateReturnType(it, emptyInput = true) } + val updatedType = resultKType?.toConeKotlinType() ?: originalType + return simpleColumnOf(column.name, updatedType) +} + +val skipNaN = true +val sum = Aggregators.sum(skipNaN) + +/** Adds to the schema only numerical columns. */ +abstract class Aggregator0 : AbstractSchemaModificationInterpreter() { + private val Arguments.receiver: PluginDataFrameSchema by dataFrame() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val resolvedColumns = receiver.columns() + .filterIsInstance() + .filter { it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) } + + val newColumns = generateStatisticResultColumns(sum, resolvedColumns) + + return PluginDataFrameSchema(receiver.columns() + newColumns) + } +} + +/** Implementation for `sum`. */ +class Sum0 : Aggregator0() + +/** Adds to the schema only numerical columns. */ +abstract class Aggregator1 : AbstractSchemaModificationInterpreter() { + private val Arguments.receiver: PluginDataFrameSchema by dataFrame() + private val Arguments.columns: ColumnsResolver by arg() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val resolvedColumns = columns.resolve(receiver).map { it.column }.filterIsInstance().toList() + + val newColumns = generateStatisticResultColumns(sum, resolvedColumns) + + return PluginDataFrameSchema(receiver.columns() + newColumns) + } +} + +/** Implementation for `sum`. */ +class Sum1 : Aggregator1() diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index fe48669d96..b9d75ff2e0 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -200,6 +200,8 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ReorderColumnsByName import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single2 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Sum0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Sum1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ValueCols2 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Take0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Take1 @@ -453,6 +455,8 @@ internal inline fun String.load(): T { "ReorderColumnsByName" -> ReorderColumnsByName() "Reorder" -> Reorder() "ByName" -> ByName() + "Sum0" -> Sum0() + "Sum1" -> Sum1() "GroupByCount0" -> GroupByCount0() "GroupByMean0" -> GroupByMean0() "GroupByMean1" -> GroupByMean1() diff --git a/plugins/kotlin-dataframe/testData/box/sum.kt b/plugins/kotlin-dataframe/testData/box/sum.kt new file mode 100644 index 0000000000..09d6a0b63f --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/sum.kt @@ -0,0 +1,59 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + // multiple columns + val personsDf = dataFrameOf( + "name", + "age", + "city", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + )( + "Alice", 15, "London", 99.5, "1.85", 50, 0.toShort(), 0.toByte(), 0L, + "Bob", 20, "Paris", 140.0, "1.35", 45, 2.toShort(), 0.toByte(), 12000L, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, 70.toShort(), 0.toByte(), 0L, + "Rose", 1, "Moscow", 45.33, "0.79", 64, 0.toShort(), 2.toByte(), 0L, + "Dylan", 35, "London", 23.4, "1.83", 30, 15.toShort(), 1.toByte(), 90000L, + "Eve", 40, "Paris", 56.72, "1.85", 25, 18.toShort(), 3.toByte(), 125000L, + "Frank", 55, "Dubai", 78.9, "1.35", 10, 35.toShort(), 2.toByte(), 145000L, + "Grace", 29, "Moscow", 67.8, "1.65", 36, 5.toShort(), 1.toByte(), 70000L, + "Hank", 60, "Paris", 80.22, "1.75", 5, 40.toShort(), 4.toByte(), 200000L, + "Isla", 22, "London", 75.1, "1.85", 43, 1.toShort(), 0.toByte(), 30000L, + ) + + // scenario #0: all numerical columns + val res0 = personsDf.sum() + + val sum01: Int? = res0.age + val sum02: Double? = res0.weight + val sum03: Int? = res0.yearsToRetirement + val sum04: Int? = res0.workExperienceYears + val sum05: Int? = res0.dependentsCount + val sum06: Long? = res0.annualIncome + + // scenario #1: particular column + val res1 = personsDf.sumFor { age } + val sum11: Int? = res1.age + + // scenario #1.1: particular column with converted type + val res11 = personsDf.sumFor { dependentsCount } + val sum111: Int? = res11.dependentsCount + + // scenario #2: sum of values per columns separately + val res3 = personsDf.sumFor { age and weight and workExperienceYears and dependentsCount and annualIncome } + + val sum31: Int? = res3.age + val sum32: Double? = res3.weight + val sum33: Int? = res3.workExperienceYears + val sum34: Int? = res3.dependentsCount + val sum35: Long? = res3.annualIncome + + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index ca695cfb4f..2a138dca35 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -6,7 +6,6 @@ import org.jetbrains.kotlin.test.util.KtTestUtil; import org.jetbrains.kotlin.test.TargetBackend; import org.jetbrains.kotlin.test.TestMetadata; -import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -308,8 +307,7 @@ public void testGroupBy_mean() { @Test @TestMetadata("groupBy_median.kt") public void testGroupBy_median() { - Assumptions.assumeTrue(false, "ignoring median test while compiler plugin support is pending."); - runTest("testData/box/groupBy_median.kt"); + runTest("testData/box/groupBy_median.kt"); } @Test @@ -612,6 +610,12 @@ public void testSelectionDsl() { runTest("testData/box/selectionDsl.kt"); } + @Test + @TestMetadata("sum.kt") + public void testSum() { + runTest("testData/box/sum.kt"); + } + @Test @TestMetadata("toDataFrame.kt") public void testToDataFrame() {