From 7729d0386a9fdd9534e298804e8c71a089dd8611 Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Wed, 23 Apr 2025 15:24:35 +0200 Subject: [PATCH 1/8] Add support for DataFrame `sum` operation with tests Introduced the `sum` operation for DataFrames, supporting numerical columns aggregation. Updated relevant tests and added new test cases to verify functionality. Included schema modifications for handling numerical column operations. --- .../jetbrains/kotlinx/dataframe/api/sum.kt | 3 +- .../kotlinx/dataframe/api/statistics.kt | 175 ++++++++++++++++-- .../dataframe/plugin/impl/api/statistics.kt | 34 ++++ .../dataframe/plugin/loadInterpreter.kt | 2 + plugins/kotlin-dataframe/testData/box/sum.kt | 71 +++++++ ...DataFrameBlackBoxCodegenTestGenerated.java | 6 + 6 files changed, 273 insertions(+), 18 deletions(-) create mode 100644 plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt create mode 100644 plugins/kotlin-dataframe/testData/box/sum.kt diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index 76b470bbf4..10b38c100a 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -125,7 +125,8 @@ public fun AnyRow.rowSumOf(type: KType, skipNaN: Boolean = skipNaNDefault): Numb // endregion // region DataFrame - +@Refine +@Interpretable("Sum0") public fun DataFrame.sum(skipNaN: Boolean = skipNaNDefault): DataRow = sumFor(skipNaN, primitiveOrMixedNumberColumns()) diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt index 1d4d76b637..7a67556adc 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt @@ -5,24 +5,107 @@ import org.junit.Test @Suppress("ktlint:standard:argument-list-wrapping") class StatisticsTests { - private val personsDf = dataFrameOf("name", "age", "city", "weight", "height", "yearsToRetirement")( - "Alice", 15, "London", 99.5, "1.85", 50, - "Bob", 20, "Paris", 140.0, "1.35", 45, - "Charlie", 100, "Dubai", 75.0, "1.95", 0, - "Rose", 1, "Moscow", 45.33, "0.79", 64, - "Dylan", 35, "London", 23.4, "1.83", 30, - "Eve", 40, "Paris", 56.72, "1.85", 25, - "Frank", 55, "Dubai", 78.9, "1.35", 10, - "Grace", 29, "Moscow", 67.8, "1.65", 36, - "Hank", 60, "Paris", 80.22, "1.75", 5, - "Isla", 22, "London", 75.1, "1.85", 43, + private val personsDf = dataFrameOf( + "name", + "age", + "city", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + )( + "Alice", 15, "London", 99.5, "1.85", 50, 0.toShort(), 0.toByte(), 0L, + "Bob", 20, "Paris", 140.0, "1.35", 45, 2.toShort(), 0.toByte(), 12000L, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, 70.toShort(), 0.toByte(), 0L, + "Rose", 1, "Moscow", 45.33, "0.79", 64, 0.toShort(), 2.toByte(), 0L, + "Dylan", 35, "London", 23.4, "1.83", 30, 15.toShort(), 1.toByte(), 90000L, + "Eve", 40, "Paris", 56.72, "1.85", 25, 18.toShort(), 3.toByte(), 125000L, + "Frank", 55, "Dubai", 78.9, "1.35", 10, 35.toShort(), 2.toByte(), 145000L, + "Grace", 29, "Moscow", 67.8, "1.65", 36, 5.toShort(), 1.toByte(), 70000L, + "Hank", 60, "Paris", 80.22, "1.75", 5, 40.toShort(), 4.toByte(), 200000L, + "Isla", 22, "London", 75.1, "1.85", 43, 1.toShort(), 0.toByte(), 30000L, ) + + @Test + fun `sum on DataFrame`() { + // scenario #0: all numerical columns + val res0 = personsDf.sum() + res0.columnNames() shouldBe listOf( + "age", + "weight", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + ) + + val sum01 = res0["age"] as Int + sum01 shouldBe 377 + val sum02 = res0["weight"] as Double + sum02 shouldBe 741.9699999999999 + val sum03 = res0["yearsToRetirement"] as Int + sum03 shouldBe 308 + val sum04 = res0["workExperienceYears"] as Int + sum04 shouldBe 186 + val sum05 = res0["dependentsCount"] as Int + sum05 shouldBe 13.0 + val sum06 = res0["annualIncome"] as Long + sum06 shouldBe 672000 + + // scenario #1: particular column + val res1 = personsDf.sumFor("age") + res1.columnNames() shouldBe listOf("age") + + val sum11 = res1["age"] as Int + sum11 shouldBe 377 + + // scenario #2: sum of all values in two columns of Int hierarchy of types + val res2 = personsDf.sum("age", "workExperienceYears") + res2 shouldBe 563 + + // scenario #2.1: sum of all values in two columns of different types + val res21 = personsDf.sum("age", "annualIncome") + res21 shouldBe 672377L + + val res211 = personsDf.sum("age", "weight") + res211 shouldBe 1118.9699999999998 + + // scenario #3: sum of values per columns separately + val res3 = personsDf.sumFor( "age", "weight", "workExperienceYears", "dependentsCount", "annualIncome") + res3.columnNames() shouldBe listOf("age", "weight", "workExperienceYears", "dependentsCount", "annualIncome") + + val sum31 = res3["age"] as Int + sum31 shouldBe 377 + val sum32 = res0["weight"] as Double + sum32 shouldBe 741.9699999999999 + val sum33 = res0["workExperienceYears"] as Int + sum33 shouldBe 186 + val sum34 = res0["dependentsCount"] as Int + sum34 shouldBe 13.0 + val sum35 = res0["annualIncome"] as Long + sum35 shouldBe 672000 + + // scenario #4: sum of expression evaluated for every row + val res4 = personsDf.sumOf { "weight"() * 10 + "age"() } + res4 shouldBe 7796.7 + } + @Test fun `sum on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").sum() - res0.columnNames() shouldBe listOf("city", "age", "weight", "yearsToRetirement") + res0.columnNames() shouldBe listOf( + "city", + "age", + "weight", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + ) val sum01 = res0["age"][0] as Int sum01 shouldBe 72 @@ -83,7 +166,15 @@ class StatisticsTests { fun `mean on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").mean() - res0.columnNames() shouldBe listOf("city", "age", "weight", "yearsToRetirement") + res0.columnNames() shouldBe listOf( + "city", + "age", + "weight", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + ) val mean01 = res0["age"][0] as Double mean01 shouldBe 24.0 @@ -151,6 +242,9 @@ class StatisticsTests { "weight", "height", "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" ) val median01 = res0["age"][0] as Int @@ -212,7 +306,15 @@ class StatisticsTests { fun `std on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").std() - res0.columnNames() shouldBe listOf("city", "age", "weight", "yearsToRetirement") + res0.columnNames() shouldBe listOf( + "city", + "age", + "weight", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + ) val std01 = res0["age"][0] as Double std01 shouldBe 10.14889156509222 @@ -280,6 +382,9 @@ class StatisticsTests { "weight", "height", "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" ) val min01 = res0["age"][0] as Int @@ -345,6 +450,9 @@ class StatisticsTests { "weight", "height", "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" ) // TODO: why is here weight presented? looks like inconsitency val min41 = res4["age"][0] as Int @@ -354,7 +462,17 @@ class StatisticsTests { // scenario #5: particular column via minBy and rowExpression val res5 = personsDf.groupBy("city").minBy { "age"() * 10 }.values() - res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height", "yearsToRetirement") + res4.columnNames() shouldBe listOf( + "city", + "name", + "age", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + ) val min51 = res5["age"][0] as Int min51 shouldBe 15 @@ -364,7 +482,17 @@ class StatisticsTests { fun `max on GroupBy`() { // scenario #0: all numerical columns val res0 = personsDf.groupBy("city").max() - res0.columnNames() shouldBe listOf("city", "name", "age", "weight", "height", "yearsToRetirement") + res0.columnNames() shouldBe listOf( + "city", + "name", + "age", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + ) val max01 = res0["age"][0] as Int max01 shouldBe 35 @@ -429,6 +557,9 @@ class StatisticsTests { "weight", "height", "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" ) // TODO: weight is here? val max41 = res4["age"][0] as Int @@ -438,7 +569,17 @@ class StatisticsTests { // scenario #5: particular column via maxBy and rowExpression val res5 = personsDf.groupBy("city").maxBy { "age"() * 10 }.values() - res4.columnNames() shouldBe listOf("city", "name", "age", "weight", "height", "yearsToRetirement") + res4.columnNames() shouldBe listOf( + "city", + "name", + "age", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + ) val max51 = res5["age"][0] as Int max51 shouldBe 35 diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt new file mode 100644 index 0000000000..d54cf7c4fa --- /dev/null +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt @@ -0,0 +1,34 @@ +package org.jetbrains.kotlinx.dataframe.plugin.impl.api + +import org.jetbrains.kotlin.fir.types.isSubtypeOf +import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter +import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments +import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema +import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn +import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame +import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf + +/** Adds to the schema only numerical columns. */ +abstract class Aggregator0 : AbstractSchemaModificationInterpreter() { + private val Arguments.receiver: PluginDataFrameSchema by dataFrame() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val resolvedColumns = receiver.columns() + .filterIsInstance() + .filter { it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) } + + /*val skipNaN = true + val sum = Aggregators.sum(skipNaN)*/ + + /* val newColumns = resolvedColumns + .map { col -> + simpleColumnOf(col.name, session.builtinTypes.doubleType.type) + } + .toList()*/ + + return PluginDataFrameSchema(receiver.columns() + resolvedColumns) + } +} + +/** Implementation for `sum`. */ +class Sum0 : Aggregator0() diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index fe15655f98..4718ab3d1a 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -194,6 +194,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ReorderColumnsByName import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single2 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Sum0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ValueCols2 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Take0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Take1 @@ -440,6 +441,7 @@ internal inline fun String.load(): T { "ReorderColumnsByName" -> ReorderColumnsByName() "Reorder" -> Reorder() "ByName" -> ByName() + "Sum0" -> Sum0() "GroupByCount0" -> GroupByCount0() "GroupByMean0" -> GroupByMean0() "GroupByMean1" -> GroupByMean1() diff --git a/plugins/kotlin-dataframe/testData/box/sum.kt b/plugins/kotlin-dataframe/testData/box/sum.kt new file mode 100644 index 0000000000..23c1147a94 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/sum.kt @@ -0,0 +1,71 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +fun box(): String { + // multiple columns + val personsDf = dataFrameOf( + "name", + "age", + "city", + "weight", + "height", + "yearsToRetirement", + "workExperienceYears", + "dependentsCount", + "annualIncome" + )( + "Alice", 15, "London", 99.5, "1.85", 50, 0.toShort(), 0.toByte(), 0L, + "Bob", 20, "Paris", 140.0, "1.35", 45, 2.toShort(), 0.toByte(), 12000L, + "Charlie", 100, "Dubai", 75.0, "1.95", 0, 70.toShort(), 0.toByte(), 0L, + "Rose", 1, "Moscow", 45.33, "0.79", 64, 0.toShort(), 2.toByte(), 0L, + "Dylan", 35, "London", 23.4, "1.83", 30, 15.toShort(), 1.toByte(), 90000L, + "Eve", 40, "Paris", 56.72, "1.85", 25, 18.toShort(), 3.toByte(), 125000L, + "Frank", 55, "Dubai", 78.9, "1.35", 10, 35.toShort(), 2.toByte(), 145000L, + "Grace", 29, "Moscow", 67.8, "1.65", 36, 5.toShort(), 1.toByte(), 70000L, + "Hank", 60, "Paris", 80.22, "1.75", 5, 40.toShort(), 4.toByte(), 200000L, + "Isla", 22, "London", 75.1, "1.85", 43, 1.toShort(), 0.toByte(), 30000L, + ) + + // scenario #0: all numerical columns + val res0 = personsDf.sum() + + val sum01: Int? = res0.age + val sum02: Double? = res0.weight + val sum03: Int? = res0.yearsToRetirement + val sum04: Int? = res0.workExperienceYears + val sum05: Int? = res0.dependentsCount + val sum06: Long? = res0.annualIncome + + /*// scenario #1: particular column + val res1 = personsDf.sumFor { age } + val sum11: Int? = res1.age + + // scenario #2: sum of all values in two columns of Int hierarchy of types + val res2 = personsDf.sum { age and workExperienceYears } + val intRes: Int? = res2 + + // scenario #2.1: sum of all values in two columns of different types + val res21 = personsDf.sum { age and annualIncome } + val longRes: Long? = res21 + + val res211 = personsDf.sum { age and weight } + val doubleRes: Double? = res211 + + // scenario #3: sum of values per columns separately + val res3 = personsDf.sumFor { age and weight and workExperienceYears and dependentsCount and annualIncome } + res3.compareSchemas() + + val sum31: Int? = res0.age + val sum32: Double? = res0.weight + val sum33: Int? = res0.workExperienceYears + val sum34: Int? = res0.dependentsCount + val sum35: Long? = res0.annualIncome + + // scenario #4: sum of expression evaluated for every row + val res4 = personsDf.sumOf { weight * 10 + age } + val expressionDoubleRes: Double? = res4*/ + + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index 2910a80138..8d04e62599 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -226,6 +226,12 @@ public void testFlexibleReturnType() { runTest("testData/box/flexibleReturnType.kt"); } + @Test + @TestMetadata("sum.kt") + public void testSum() { + runTest("testData/box/sum.kt"); + } + @Test @TestMetadata("group.kt") public void testGroup() { From 0376fec5d858169e73a3bac65d99df4682145c67 Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Wed, 23 Apr 2025 17:07:26 +0200 Subject: [PATCH 2/8] Make aggregator-related classes and functions public Converted various internal classes, interfaces, and functions related to aggregation into public entities. This change expands their visibility, enabling external usage and facilitating integration with other modules or libraries. --- .../aggregation/aggregators/Aggregator.kt | 15 +++++------ .../AggregatorAggregationHandler.kt | 11 ++++---- .../aggregators/AggregatorHandler.kt | 6 ++--- .../aggregators/AggregatorInputHandler.kt | 8 +++--- .../AggregatorMultipleColumnsHandler.kt | 6 ++--- .../aggregators/AggregatorOptionSwitch.kt | 26 +++++++++---------- .../aggregators/AggregatorProvider.kt | 6 ++--- .../aggregation/aggregators/Aggregators.kt | 21 +++++++-------- .../impl/aggregation/aggregators/ValueType.kt | 2 +- .../TwoStepMultipleColumnsHandler.kt | 2 +- 10 files changed, 49 insertions(+), 54 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt index c5ca55381a..9d181a6804 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt @@ -37,12 +37,11 @@ import kotlin.reflect.full.withNullability * @param Return The type of the resulting value. Can optionally be nullable. * @see [invoke] */ -@PublishedApi -internal class Aggregator( - val aggregationHandler: AggregatorAggregationHandler, - val inputHandler: AggregatorInputHandler, - val multipleColumnsHandler: AggregatorMultipleColumnsHandler, - val name: String, +public class Aggregator( + public val aggregationHandler: AggregatorAggregationHandler, + public val inputHandler: AggregatorInputHandler, + public val multipleColumnsHandler: AggregatorMultipleColumnsHandler, + public val name: String, ) : AggregatorInputHandler by inputHandler, AggregatorMultipleColumnsHandler by multipleColumnsHandler, AggregatorAggregationHandler by aggregationHandler { @@ -96,7 +95,7 @@ internal class Aggregator( internal fun Aggregator.aggregate( values: Sequence, valueType: ValueType, -) = aggregateSequence(values, valueType) +): Return = aggregateSequence(values, valueType) /** * Performs aggregation on the given [values], taking [valueType] into account. @@ -106,7 +105,7 @@ internal fun Aggregator.aggregate( internal fun Aggregator.aggregate( values: Sequence, valueType: KType, -) = aggregate(values, valueType.toValueType(needsFullConversion = false)) +): Return = aggregate(values, valueType.toValueType(needsFullConversion = false)) /** * If the specific [ValueType] of the input is not known, but you still want to call [aggregate], diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt index fc4af0ffe5..7b1b0357eb 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler.kt @@ -11,8 +11,7 @@ import kotlin.reflect.KType * It also provides information on which return type will be given, as [KType], given a [value type][ValueType]. * It can also provide the index of the result in the input values if it is a selecting aggregator. */ -@PublishedApi -internal interface AggregatorAggregationHandler : AggregatorHandler { +public interface AggregatorAggregationHandler : AggregatorHandler { /** * Base function of [Aggregator]. @@ -23,13 +22,13 @@ internal interface AggregatorAggregationHandler, valueType: ValueType): Return + public fun aggregateSequence(values: Sequence, valueType: ValueType): Return /** * Aggregates the data in the given column and computes a single resulting value. * Calls [aggregateSequence]. */ - fun aggregateSingleColumn(column: DataColumn): Return = + public fun aggregateSingleColumn(column: DataColumn): Return = aggregateSequence( values = column.asSequence(), valueType = column.type().toValueType(), @@ -43,7 +42,7 @@ internal interface AggregatorAggregationHandler, valueType: ValueType): Int + public fun indexOfAggregationResultSingleSequence(values: Sequence, valueType: ValueType): Int } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorHandler.kt index 6d7a87a02d..98558a5556 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorHandler.kt @@ -7,16 +7,16 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators * the [init] function of each [AggregatorAggregationHandlers][AggregatorAggregationHandler] is called, * which allows the handler to refer to [Aggregator] instance via [aggregator]. */ -internal interface AggregatorHandler { +public interface AggregatorHandler { /** * Reference to the aggregator instance. * * Can only be used once [init] has run. */ - var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>? + public var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>? - fun init(aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>) { + public fun init(aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>) { this.aggregator = aggregator } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler.kt index e3a2bb64ac..73c0dcde4b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler.kt @@ -8,13 +8,13 @@ import kotlin.reflect.KType * It can also calculate a specific [value type][ValueType] from the input values or input types * if the (specific) type is not known. */ -internal interface AggregatorInputHandler : AggregatorHandler { +public interface AggregatorInputHandler : AggregatorHandler { /** * If the specific [ValueType] of the input is not known, but you still want to call [aggregate], * this function can be called to calculate it by combining the set of known [valueTypes]. */ - fun calculateValueType(valueTypes: Set): ValueType + public fun calculateValueType(valueTypes: Set): ValueType /** * WARNING: HEAVY! @@ -23,7 +23,7 @@ internal interface AggregatorInputHandler : A * this function can be called to calculate it by getting the types of [values] at runtime. * This is heavy because it uses reflection on each value. */ - fun calculateValueType(values: Sequence): ValueType + public fun calculateValueType(values: Sequence): ValueType /** * Preprocesses the input values before aggregation. @@ -32,7 +32,7 @@ internal interface AggregatorInputHandler : A * * @return A pair of the preprocessed values and the (potentially new) type of the values. */ - fun preprocessAggregation( + public fun preprocessAggregation( values: Sequence, valueType: ValueType, ): Pair, KType> diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler.kt index 0530f05ae7..eb5dc1ff59 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler.kt @@ -10,14 +10,14 @@ import kotlin.reflect.KType * [AggregatorAggregationHandler]. * It can also calculate the return type of the aggregation given all input column types. */ -internal interface AggregatorMultipleColumnsHandler : +public interface AggregatorMultipleColumnsHandler : AggregatorHandler { /** * Aggregates the data in the multiple given columns and computes a single resulting value. * Calls [Aggregator.aggregateSequence] or [Aggregator.aggregateSingleColumn]. */ - fun aggregateMultipleColumns(columns: Sequence>): Return + public fun aggregateMultipleColumns(columns: Sequence>): Return /** * Function that can give the return type of [aggregateMultipleColumns], given types of the columns. @@ -26,5 +26,5 @@ internal interface AggregatorMultipleColumnsHandler, colsEmpty: Boolean): KType + public fun calculateReturnTypeMultipleColumns(colTypes: Set, colsEmpty: Boolean): KType } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt index 987771547a..87a1a37460 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt @@ -6,21 +6,20 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators * Aggregators are cached by their parameter value. * @see AggregatorOptionSwitch2 */ -@PublishedApi -internal class AggregatorOptionSwitch1( - val name: String, - val getAggregator: (param1: Param1) -> AggregatorProvider, +public class AggregatorOptionSwitch1( + public val name: String, + public val getAggregator: (param1: Param1) -> AggregatorProvider, ) { private val cache: MutableMap> = mutableMapOf() - operator fun invoke(param1: Param1): Aggregator = + public operator fun invoke(param1: Param1): Aggregator = cache.getOrPut(param1) { getAggregator(param1).create(name) } @Suppress("FunctionName") - companion object { + public companion object { /** * Creates [AggregatorOptionSwitch1]. @@ -31,7 +30,7 @@ internal class AggregatorOptionSwitch1 Factory( + internal fun Factory( getAggregator: (param1: Param1) -> AggregatorProvider, ) = Provider { name -> AggregatorOptionSwitch1(name, getAggregator) } } @@ -43,21 +42,20 @@ internal class AggregatorOptionSwitch1( - val name: String, - val getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider, +public class AggregatorOptionSwitch2( + public val name: String, + public val getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider, ) { private val cache: MutableMap, Aggregator> = mutableMapOf() - operator fun invoke(param1: Param1, param2: Param2): Aggregator = + public operator fun invoke(param1: Param1, param2: Param2): Aggregator = cache.getOrPut(param1 to param2) { getAggregator(param1, param2).create(name) } @Suppress("FunctionName") - companion object { + public companion object { /** * Creates [AggregatorOptionSwitch2]. @@ -68,7 +66,7 @@ internal class AggregatorOptionSwitch2 Factory( + internal fun Factory( getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider, ) = Provider { name -> AggregatorOptionSwitch2(name, getAggregator) } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt index fd1ef8db35..158995cfea 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt @@ -10,9 +10,9 @@ import kotlin.reflect.KProperty * val myNamedValue by MyFactory * ``` */ -internal fun interface Provider { +public fun interface Provider { - fun create(name: String): T + public fun create(name: String): T } internal operator fun Provider.getValue(obj: Any?, property: KProperty<*>): T = create(property.name) @@ -25,4 +25,4 @@ internal operator fun Provider.getValue(obj: Any?, property: KProperty<*> * val myAggregator by MyAggregator.Factory * ``` */ -internal fun interface AggregatorProvider : Provider> +public fun interface AggregatorProvider : Provider> diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index 993f32104e..d5b044eb12 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -25,8 +25,7 @@ import org.jetbrains.kotlinx.dataframe.math.stdTypeConversion import org.jetbrains.kotlinx.dataframe.math.sum import org.jetbrains.kotlinx.dataframe.math.sumTypeConversion -@PublishedApi -internal object Aggregators { +public object Aggregators { // TODO these might need some small refactoring @@ -110,7 +109,7 @@ internal object Aggregators { // T: Comparable -> T? // T : Comparable? -> T? - fun ?> min(skipNaN: Boolean): Aggregator = min.invoke(skipNaN).cast2() + public fun ?> min(skipNaN: Boolean): Aggregator = min.invoke(skipNaN).cast2() private val min by withOneOption { skipNaN: Boolean -> twoStepSelectingForAny, Comparable?>( @@ -122,7 +121,7 @@ internal object Aggregators { // T: Comparable -> T? // T : Comparable? -> T? - fun ?> max(skipNaN: Boolean): Aggregator = max.invoke(skipNaN).cast2() + public fun ?> max(skipNaN: Boolean): Aggregator = max.invoke(skipNaN).cast2() private val max by withOneOption { skipNaN: Boolean -> twoStepSelectingForAny, Comparable?>( @@ -133,7 +132,7 @@ internal object Aggregators { } // T: Number? -> Double - val std by withTwoOptions { skipNaN: Boolean, ddof: Int -> + public val std: AggregatorOptionSwitch2 by withTwoOptions { skipNaN: Boolean, ddof: Int -> flattenReducingForNumbers(stdTypeConversion) { type -> std(type, skipNaN, ddof) } @@ -141,14 +140,14 @@ internal object Aggregators { // step one: T: Number? -> Double // step two: Double -> Double - val mean by withOneOption { skipNaN: Boolean -> + public val mean: AggregatorOptionSwitch1 by withOneOption { skipNaN: Boolean -> twoStepReducingForNumbers(meanTypeConversion) { type -> mean(type, skipNaN) } } // T: Comparable? -> T - val percentile by withOneOption { percentile: Double -> + public val percentile: AggregatorOptionSwitch1, Comparable?> by withOneOption { percentile: Double -> flattenReducingForAny> { type -> asIterable().percentile(percentile, type) } @@ -156,17 +155,17 @@ internal object Aggregators { // T : primitive Number? -> Double? // T : Comparable? -> T? - fun medianCommon(skipNaN: Boolean): Aggregator + public fun medianCommon(skipNaN: Boolean): Aggregator where T : Comparable? = median.invoke(skipNaN).cast2() // T : Comparable? -> T? - fun medianComparables(): Aggregator + public fun medianComparables(): Aggregator where T : Comparable? = medianCommon(skipNaNDefault).cast2() // T : primitive Number? -> Double? - fun medianNumbers( + public fun medianNumbers( skipNaN: Boolean, ): Aggregator where T : Comparable?, T : Number? = @@ -182,7 +181,7 @@ internal object Aggregators { } // T: Number -> T - val sum by withOneOption { skipNaN: Boolean -> + public val sum: AggregatorOptionSwitch1 by withOneOption { skipNaN: Boolean -> twoStepReducingForNumbers(sumTypeConversion) { type -> sum(type, skipNaN) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType.kt index be886c9578..91157483ff 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType.kt @@ -10,6 +10,6 @@ import kotlin.reflect.KType * for the values to become the correct value type. If `false`, the values are already the right type, * or a simple cast will suffice. */ -internal data class ValueType(val kType: KType, val needsFullConversion: Boolean = false) +public data class ValueType(val kType: KType, val needsFullConversion: Boolean = false) internal fun KType.toValueType(needsFullConversion: Boolean = false): ValueType = ValueType(this, needsFullConversion) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt index e53e6470a4..c24a3e34a8 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/multipleColumnsHandlers/TwoStepMultipleColumnsHandler.kt @@ -31,7 +31,7 @@ import kotlin.reflect.KType * If not supplied, the handler of the first step is reused. * @see [FlatteningMultipleColumnsHandler] */ -internal class TwoStepMultipleColumnsHandler( +internal class TwoStepMultipleColumnsHandler( stepTwoAggregationHandler: AggregatorAggregationHandler? = null, stepTwoInputHandler: AggregatorInputHandler? = null, ) : AggregatorMultipleColumnsHandler { From 644685dc50b0e946d01409fcbdd3f5c2a4820139 Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Wed, 23 Apr 2025 18:33:41 +0200 Subject: [PATCH 3/8] Enhance type conversions between `KType` and `ConeKotlinType` to ensure compatibility and correctness in sum calculations. --- .../jetbrains/kotlinx/dataframe/api/sum.kt | 2 + .../dataframe/plugin/impl/api/statistics.kt | 102 ++++++++++++++++-- .../dataframe/plugin/loadInterpreter.kt | 2 + plugins/kotlin-dataframe/testData/box/sum.kt | 27 ++--- 4 files changed, 112 insertions(+), 21 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index 10b38c100a..c0d9dae69c 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -130,6 +130,8 @@ public fun AnyRow.rowSumOf(type: KType, skipNaN: Boolean = skipNaNDefault): Numb public fun DataFrame.sum(skipNaN: Boolean = skipNaNDefault): DataRow = sumFor(skipNaN, primitiveOrMixedNumberColumns()) +@Refine +@Interpretable("Sum1") public fun DataFrame.sumFor( skipNaN: Boolean = skipNaNDefault, columns: ColumnsForAggregateSelector, diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt index d54cf7c4fa..ddab8f7259 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt @@ -1,12 +1,86 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api +import org.jetbrains.kotlin.fir.symbols.ConeClassLikeLookupTag +import org.jetbrains.kotlin.fir.types.ConeClassLikeErrorLookupTag import org.jetbrains.kotlin.fir.types.isSubtypeOf +import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame import org.jetbrains.kotlinx.dataframe.plugin.impl.simpleColumnOf +import kotlin.reflect.KType +import kotlin.reflect.full.createType +import org.jetbrains.kotlin.fir.types.ConeKotlinType +import org.jetbrains.kotlin.fir.types.ConeClassLikeType +import org.jetbrains.kotlin.fir.types.ConeNullability +import org.jetbrains.kotlin.fir.types.impl.ConeClassLikeTypeImpl +import org.jetbrains.kotlin.name.ClassId +import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator +import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol +import kotlin.reflect.KClass + +fun ConeKotlinType.toKType(): KType? { + if (this is ConeClassLikeType) { + val classId: ClassId = this.lookupTag.classId + val isNullable = this.nullability == ConeNullability.NULLABLE + return when (classId.asString()) { + "kotlin/Int" -> Int::class.createType(nullable = isNullable) + "kotlin/Long" -> Long::class.createType(nullable = isNullable) + "kotlin/Double" -> Double::class.createType(nullable = isNullable) + "kotlin/Float" -> Float::class.createType(nullable = isNullable) + "kotlin/Short" -> Short::class.createType(nullable = isNullable) + "kotlin/Byte" -> Byte::class.createType(nullable = isNullable) + else -> null + } + } + return null +} + +fun KType.toConeKotlinType(): ConeKotlinType? { + val kClass = this.classifier as? KClass<*> ?: return null + + val classId = when (kClass) { + Int::class -> ClassId.fromString("kotlin/Int") + Long::class -> ClassId.fromString("kotlin/Long") + Double::class -> ClassId.fromString("kotlin/Double") + Float::class -> ClassId.fromString("kotlin/Float") + Short::class -> ClassId.fromString("kotlin/Short") + Byte::class -> ClassId.fromString("kotlin/Byte") + else -> return null + } + + val lookupTag: ConeClassLikeLookupTag = ConeClassLikeErrorLookupTag(classId) + + val nullability = if (this.isMarkedNullable) ConeNullability.NULLABLE else ConeNullability.NOT_NULL + + return ConeClassLikeTypeImpl( + lookupTag = lookupTag, + typeArguments = emptyArray(), + isNullable = (nullability == ConeNullability.NULLABLE) + ) +} + +private fun Arguments.calculateResultTypeForStatistic( + aggregator: Aggregator, + resolvedColumns: List +): List { + val newColumns = resolvedColumns + .map { col -> + val columnType = col.type.type + val inputColumnKType = columnType.toKType() + val calculatedReturnKType = inputColumnKType?.let { aggregator.calculateReturnType(it, emptyInput = false) } + val updatedType = + calculatedReturnKType?.toConeKotlinType() ?: columnType // stay with input type by default + simpleColumnOf(col.name, updatedType) + } + .toList() + return newColumns +} + +val skipNaN = true +val sum = Aggregators.sum(skipNaN) /** Adds to the schema only numerical columns. */ abstract class Aggregator0 : AbstractSchemaModificationInterpreter() { @@ -17,18 +91,28 @@ abstract class Aggregator0 : AbstractSchemaModificationInterpreter() { .filterIsInstance() .filter { it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) } - /*val skipNaN = true - val sum = Aggregators.sum(skipNaN)*/ + val newColumns = calculateResultTypeForStatistic(sum, resolvedColumns) - /* val newColumns = resolvedColumns - .map { col -> - simpleColumnOf(col.name, session.builtinTypes.doubleType.type) - } - .toList()*/ - - return PluginDataFrameSchema(receiver.columns() + resolvedColumns) + return PluginDataFrameSchema(receiver.columns() + newColumns) } } /** Implementation for `sum`. */ class Sum0 : Aggregator0() + +/** Adds to the schema only numerical columns. */ +abstract class Aggregator1 : AbstractSchemaModificationInterpreter() { + private val Arguments.receiver: PluginDataFrameSchema by dataFrame() + private val Arguments.columns: ColumnsResolver by arg() + + override fun Arguments.interpret(): PluginDataFrameSchema { + val resolvedColumns = columns.resolve(receiver).map { it.column }.filterIsInstance().toList() + + val newColumns = calculateResultTypeForStatistic(sum, resolvedColumns) + + return PluginDataFrameSchema(receiver.columns() + newColumns) + } +} + +/** Implementation for `sum`. */ +class Sum1 : Aggregator1() diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt index 651dcb8a5c..33f1211cf5 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/loadInterpreter.kt @@ -198,6 +198,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Single2 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Sum0 +import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Sum1 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ValueCols2 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Take0 import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Take1 @@ -449,6 +450,7 @@ internal inline fun String.load(): T { "Reorder" -> Reorder() "ByName" -> ByName() "Sum0" -> Sum0() + "Sum1" -> Sum1() "GroupByCount0" -> GroupByCount0() "GroupByMean0" -> GroupByMean0() "GroupByMean1" -> GroupByMean1() diff --git a/plugins/kotlin-dataframe/testData/box/sum.kt b/plugins/kotlin-dataframe/testData/box/sum.kt index 23c1147a94..c379542212 100644 --- a/plugins/kotlin-dataframe/testData/box/sum.kt +++ b/plugins/kotlin-dataframe/testData/box/sum.kt @@ -38,34 +38,37 @@ fun box(): String { val sum05: Int? = res0.dependentsCount val sum06: Long? = res0.annualIncome - /*// scenario #1: particular column + // scenario #1: particular column val res1 = personsDf.sumFor { age } val sum11: Int? = res1.age + // scenario #1.1: particular column with converted type + val res11 = personsDf.sumFor { dependentsCount } + val sum111: Int? = res1.age + // scenario #2: sum of all values in two columns of Int hierarchy of types val res2 = personsDf.sum { age and workExperienceYears } - val intRes: Int? = res2 + val intRes: Int? = (res2 as? Number)?.toInt() // scenario #2.1: sum of all values in two columns of different types val res21 = personsDf.sum { age and annualIncome } - val longRes: Long? = res21 + val longRes: Long? = (res21 as? Number)?.toLong() val res211 = personsDf.sum { age and weight } - val doubleRes: Double? = res211 + val doubleRes: Double? = (res211 as? Number)?.toDouble() // scenario #3: sum of values per columns separately val res3 = personsDf.sumFor { age and weight and workExperienceYears and dependentsCount and annualIncome } - res3.compareSchemas() - val sum31: Int? = res0.age - val sum32: Double? = res0.weight - val sum33: Int? = res0.workExperienceYears - val sum34: Int? = res0.dependentsCount - val sum35: Long? = res0.annualIncome + val sum31: Int? = res3.age + val sum32: Double? = res3.weight + val sum33: Int? = res3.workExperienceYears + val sum34: Int? = res3.dependentsCount + val sum35: Long? = res3.annualIncome // scenario #4: sum of expression evaluated for every row - val res4 = personsDf.sumOf { weight * 10 + age } - val expressionDoubleRes: Double? = res4*/ + val res4 = personsDf.sumOf { weight * 10 + dependentsCount } + val expressionDoubleRes: Double? = res4 return "OK" } From 2db823fd34e32067e99e02fcbb7e017e7ee48787 Mon Sep 17 00:00:00 2001 From: Aleksei Zinovev Date: Wed, 23 Apr 2025 18:39:26 +0200 Subject: [PATCH 4/8] Update plugins/kotlin-dataframe/testData/box/sum.kt Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- plugins/kotlin-dataframe/testData/box/sum.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/kotlin-dataframe/testData/box/sum.kt b/plugins/kotlin-dataframe/testData/box/sum.kt index c379542212..af0575d12d 100644 --- a/plugins/kotlin-dataframe/testData/box/sum.kt +++ b/plugins/kotlin-dataframe/testData/box/sum.kt @@ -44,7 +44,7 @@ fun box(): String { // scenario #1.1: particular column with converted type val res11 = personsDf.sumFor { dependentsCount } - val sum111: Int? = res1.age + val sum111: Int? = res11.dependentsCount // scenario #2: sum of all values in two columns of Int hierarchy of types val res2 = personsDf.sum { age and workExperienceYears } From 9844b7d46bf72bda510255976f225277a19768f1 Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Wed, 23 Apr 2025 18:50:16 +0200 Subject: [PATCH 5/8] Refactor type conversion and column handling logic --- .../dataframe/plugin/impl/api/statistics.kt | 98 ++++++++++--------- 1 file changed, 52 insertions(+), 46 deletions(-) diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt index ddab8f7259..734d1c3390 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt @@ -1,6 +1,5 @@ package org.jetbrains.kotlinx.dataframe.plugin.impl.api -import org.jetbrains.kotlin.fir.symbols.ConeClassLikeLookupTag import org.jetbrains.kotlin.fir.types.ConeClassLikeErrorLookupTag import org.jetbrains.kotlin.fir.types.isSubtypeOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators @@ -21,62 +20,69 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleCol import kotlin.reflect.KClass +private object PrimitiveClassIds { + const val INT = "kotlin/Int" + const val LONG = "kotlin/Long" + const val DOUBLE = "kotlin/Double" + const val FLOAT = "kotlin/Float" + const val SHORT = "kotlin/Short" + const val BYTE = "kotlin/Byte" +} + +private fun KClass<*>.toClassId(): ClassId? = when (this) { + Int::class -> ClassId.fromString(PrimitiveClassIds.INT) + Long::class -> ClassId.fromString(PrimitiveClassIds.LONG) + Double::class -> ClassId.fromString(PrimitiveClassIds.DOUBLE) + Float::class -> ClassId.fromString(PrimitiveClassIds.FLOAT) + Short::class -> ClassId.fromString(PrimitiveClassIds.SHORT) + Byte::class -> ClassId.fromString(PrimitiveClassIds.BYTE) + else -> null +} + +private val primitiveTypeMap = mapOf( + PrimitiveClassIds.INT to Int::class, + PrimitiveClassIds.LONG to Long::class, + PrimitiveClassIds.DOUBLE to Double::class, + PrimitiveClassIds.FLOAT to Float::class, + PrimitiveClassIds.SHORT to Short::class, + PrimitiveClassIds.BYTE to Byte::class +) + fun ConeKotlinType.toKType(): KType? { - if (this is ConeClassLikeType) { - val classId: ClassId = this.lookupTag.classId - val isNullable = this.nullability == ConeNullability.NULLABLE - return when (classId.asString()) { - "kotlin/Int" -> Int::class.createType(nullable = isNullable) - "kotlin/Long" -> Long::class.createType(nullable = isNullable) - "kotlin/Double" -> Double::class.createType(nullable = isNullable) - "kotlin/Float" -> Float::class.createType(nullable = isNullable) - "kotlin/Short" -> Short::class.createType(nullable = isNullable) - "kotlin/Byte" -> Byte::class.createType(nullable = isNullable) - else -> null - } + return (this as? ConeClassLikeType)?.let { coneType -> + val nullable = coneType.nullability == ConeNullability.NULLABLE + primitiveTypeMap[coneType.lookupTag.classId.asString()] + ?.createType(nullable = nullable) } - return null } fun KType.toConeKotlinType(): ConeKotlinType? { val kClass = this.classifier as? KClass<*> ?: return null - - val classId = when (kClass) { - Int::class -> ClassId.fromString("kotlin/Int") - Long::class -> ClassId.fromString("kotlin/Long") - Double::class -> ClassId.fromString("kotlin/Double") - Float::class -> ClassId.fromString("kotlin/Float") - Short::class -> ClassId.fromString("kotlin/Short") - Byte::class -> ClassId.fromString("kotlin/Byte") - else -> return null - } - - val lookupTag: ConeClassLikeLookupTag = ConeClassLikeErrorLookupTag(classId) - - val nullability = if (this.isMarkedNullable) ConeNullability.NULLABLE else ConeNullability.NOT_NULL + val classId = kClass.toClassId() ?: return null return ConeClassLikeTypeImpl( - lookupTag = lookupTag, + lookupTag = ConeClassLikeErrorLookupTag(classId), typeArguments = emptyArray(), - isNullable = (nullability == ConeNullability.NULLABLE) + isNullable = this.isMarkedNullable ) } -private fun Arguments.calculateResultTypeForStatistic( - aggregator: Aggregator, - resolvedColumns: List +private fun Arguments.generateStatisticResultColumns( + statisticAggregator: Aggregator, + inputColumns: List ): List { - val newColumns = resolvedColumns - .map { col -> - val columnType = col.type.type - val inputColumnKType = columnType.toKType() - val calculatedReturnKType = inputColumnKType?.let { aggregator.calculateReturnType(it, emptyInput = false) } - val updatedType = - calculatedReturnKType?.toConeKotlinType() ?: columnType // stay with input type by default - simpleColumnOf(col.name, updatedType) - } - .toList() - return newColumns + return inputColumns.map { col -> createUpdatedColumn(col, statisticAggregator) } +} + +private fun Arguments.createUpdatedColumn( + column: SimpleDataColumn, + statisticAggregator: Aggregator +): SimpleCol { + val originalType = column.type.type + val inputKType = originalType.toKType() + val resultKType = inputKType?.let { statisticAggregator.calculateReturnType(it, emptyInput = false) } + val updatedType = resultKType?.toConeKotlinType() ?: originalType + return simpleColumnOf(column.name, updatedType) } val skipNaN = true @@ -91,7 +97,7 @@ abstract class Aggregator0 : AbstractSchemaModificationInterpreter() { .filterIsInstance() .filter { it.type.type.isSubtypeOf(session.builtinTypes.numberType.type, session) } - val newColumns = calculateResultTypeForStatistic(sum, resolvedColumns) + val newColumns = generateStatisticResultColumns(sum, resolvedColumns) return PluginDataFrameSchema(receiver.columns() + newColumns) } @@ -108,7 +114,7 @@ abstract class Aggregator1 : AbstractSchemaModificationInterpreter() { override fun Arguments.interpret(): PluginDataFrameSchema { val resolvedColumns = columns.resolve(receiver).map { it.column }.filterIsInstance().toList() - val newColumns = calculateResultTypeForStatistic(sum, resolvedColumns) + val newColumns = generateStatisticResultColumns(sum, resolvedColumns) return PluginDataFrameSchema(receiver.columns() + newColumns) } From fc42da38187808a9108dc580c544b5daf0eb8ee0 Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Sat, 26 Apr 2025 16:40:27 +0200 Subject: [PATCH 6/8] Fixed review --- .../aggregation/aggregators/Aggregator.kt | 2 +- .../aggregators/AggregatorOptionSwitch.kt | 4 ++-- .../kotlinx/dataframe/api/statistics.kt | 20 ++++++------------- .../dataframe/plugin/impl/api/statistics.kt | 6 +++--- plugins/kotlin-dataframe/testData/box/sum.kt | 17 +--------------- ...DataFrameBlackBoxCodegenTestGenerated.java | 20 +++++++++---------- 6 files changed, 22 insertions(+), 47 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt index 9d181a6804..a7cb2a8b14 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator.kt @@ -37,7 +37,7 @@ import kotlin.reflect.full.withNullability * @param Return The type of the resulting value. Can optionally be nullable. * @see [invoke] */ -public class Aggregator( +public class Aggregator( public val aggregationHandler: AggregatorAggregationHandler, public val inputHandler: AggregatorInputHandler, public val multipleColumnsHandler: AggregatorMultipleColumnsHandler, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt index 87a1a37460..f8853a64b6 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt @@ -30,9 +30,9 @@ public class AggregatorOptionSwitch1 Factory( + public fun Factory( getAggregator: (param1: Param1) -> AggregatorProvider, - ) = Provider { name -> AggregatorOptionSwitch1(name, getAggregator) } + ): Provider> = Provider { name -> AggregatorOptionSwitch1(name, getAggregator) } } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt index 8e0e990ae4..db70a24130 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt @@ -62,18 +62,14 @@ class StatisticsTests { val sum11 = res1["age"] as Int sum11 shouldBe 377 - // scenario #2: sum of all values in two columns of Int hierarchy of types - val res2 = personsDf.sum("age", "workExperienceYears") - res2 shouldBe 563 + // scenario #1.1: particular column with converted type + val res11 = personsDf.sumFor("dependentsCount") + res11.columnNames() shouldBe listOf("dependentsCount") - // scenario #2.1: sum of all values in two columns of different types - val res21 = personsDf.sum("age", "annualIncome") - res21 shouldBe 672377L + val sum111 = res11["dependentsCount"] as Int + sum111 shouldBe 13 - val res211 = personsDf.sum("age", "weight") - res211 shouldBe 1118.9699999999998 - - // scenario #3: sum of values per columns separately + // scenario #2: sum of values per columns separately val res3 = personsDf.sumFor( "age", "weight", "workExperienceYears", "dependentsCount", "annualIncome") res3.columnNames() shouldBe listOf("age", "weight", "workExperienceYears", "dependentsCount", "annualIncome") @@ -87,10 +83,6 @@ class StatisticsTests { sum34 shouldBe 13.0 val sum35 = res0["annualIncome"] as Long sum35 shouldBe 672000 - - // scenario #4: sum of expression evaluated for every row - val res4 = personsDf.sumOf { "weight"() * 10 + "age"() } - res4 shouldBe 7796.7 } @Test diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt index 734d1c3390..079556098d 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/statistics.kt @@ -14,6 +14,7 @@ import kotlin.reflect.full.createType import org.jetbrains.kotlin.fir.types.ConeKotlinType import org.jetbrains.kotlin.fir.types.ConeClassLikeType import org.jetbrains.kotlin.fir.types.ConeNullability +import org.jetbrains.kotlin.fir.types.constructClassLikeType import org.jetbrains.kotlin.fir.types.impl.ConeClassLikeTypeImpl import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator @@ -60,8 +61,7 @@ fun KType.toConeKotlinType(): ConeKotlinType? { val kClass = this.classifier as? KClass<*> ?: return null val classId = kClass.toClassId() ?: return null - return ConeClassLikeTypeImpl( - lookupTag = ConeClassLikeErrorLookupTag(classId), + return classId.constructClassLikeType( typeArguments = emptyArray(), isNullable = this.isMarkedNullable ) @@ -80,7 +80,7 @@ private fun Arguments.createUpdatedColumn( ): SimpleCol { val originalType = column.type.type val inputKType = originalType.toKType() - val resultKType = inputKType?.let { statisticAggregator.calculateReturnType(it, emptyInput = false) } + val resultKType = inputKType?.let { statisticAggregator.calculateReturnType(it, emptyInput = true) } val updatedType = resultKType?.toConeKotlinType() ?: originalType return simpleColumnOf(column.name, updatedType) } diff --git a/plugins/kotlin-dataframe/testData/box/sum.kt b/plugins/kotlin-dataframe/testData/box/sum.kt index af0575d12d..09d6a0b63f 100644 --- a/plugins/kotlin-dataframe/testData/box/sum.kt +++ b/plugins/kotlin-dataframe/testData/box/sum.kt @@ -46,18 +46,7 @@ fun box(): String { val res11 = personsDf.sumFor { dependentsCount } val sum111: Int? = res11.dependentsCount - // scenario #2: sum of all values in two columns of Int hierarchy of types - val res2 = personsDf.sum { age and workExperienceYears } - val intRes: Int? = (res2 as? Number)?.toInt() - - // scenario #2.1: sum of all values in two columns of different types - val res21 = personsDf.sum { age and annualIncome } - val longRes: Long? = (res21 as? Number)?.toLong() - - val res211 = personsDf.sum { age and weight } - val doubleRes: Double? = (res211 as? Number)?.toDouble() - - // scenario #3: sum of values per columns separately + // scenario #2: sum of values per columns separately val res3 = personsDf.sumFor { age and weight and workExperienceYears and dependentsCount and annualIncome } val sum31: Int? = res3.age @@ -66,9 +55,5 @@ fun box(): String { val sum34: Int? = res3.dependentsCount val sum35: Long? = res3.annualIncome - // scenario #4: sum of expression evaluated for every row - val res4 = personsDf.sumOf { weight * 10 + dependentsCount } - val expressionDoubleRes: Double? = res4 - return "OK" } diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index b062e52969..2b76623d3d 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -6,7 +6,6 @@ import org.jetbrains.kotlin.test.util.KtTestUtil; import org.jetbrains.kotlin.test.TargetBackend; import org.jetbrains.kotlin.test.TestMetadata; -import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; @@ -158,8 +157,8 @@ public void testDiff() { @Test @TestMetadata("distinct.kt") public void testDistinct() { - runTest("testData/box/distinct.kt"); - } + runTest("testData/box/distinct.kt"); + } @Test @TestMetadata("dropNA.kt") @@ -233,12 +232,6 @@ public void testFlexibleReturnType() { runTest("testData/box/flexibleReturnType.kt"); } - @Test - @TestMetadata("sum.kt") - public void testSum() { - runTest("testData/box/sum.kt"); - } - @Test @TestMetadata("group.kt") public void testGroup() { @@ -296,8 +289,7 @@ public void testGroupBy_mean() { @Test @TestMetadata("groupBy_median.kt") public void testGroupBy_median() { - Assumptions.assumeTrue(false, "ignoring median test while compiler plugin support is pending."); - runTest("testData/box/groupBy_median.kt"); + runTest("testData/box/groupBy_median.kt"); } @Test @@ -600,6 +592,12 @@ public void testSelectionDsl() { runTest("testData/box/selectionDsl.kt"); } + @Test + @TestMetadata("sum.kt") + public void testSum() { + runTest("testData/box/sum.kt"); + } + @Test @TestMetadata("toDataFrame.kt") public void testToDataFrame() { From 51822bea50b2cb033a78c4b3e0dda95586b61f3d Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Sat, 26 Apr 2025 17:04:56 +0200 Subject: [PATCH 7/8] Fixed conflict --- .../impl/aggregation/aggregators/Aggregators.kt | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index 6a30903bb7..1c75d42d47 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -19,6 +19,9 @@ import org.jetbrains.kotlinx.dataframe.math.medianConversion import org.jetbrains.kotlinx.dataframe.math.medianOrNull import org.jetbrains.kotlinx.dataframe.math.minOrNull import org.jetbrains.kotlinx.dataframe.math.minTypeConversion +import org.jetbrains.kotlinx.dataframe.math.percentileConversion +import org.jetbrains.kotlinx.dataframe.math.percentileOrNull +import org.jetbrains.kotlinx.dataframe.math.indexOfPercentile import org.jetbrains.kotlinx.dataframe.math.std import org.jetbrains.kotlinx.dataframe.math.stdTypeConversion import org.jetbrains.kotlinx.dataframe.math.sum @@ -108,7 +111,7 @@ public object Aggregators { // T: Comparable -> T? // T : Comparable? -> T? - fun ?> min(skipNaN: Boolean): Aggregator = min.invoke(skipNaN).cast2() + public fun ?> min(skipNaN: Boolean): Aggregator = min.invoke(skipNaN).cast2() private val min by withOneOption { skipNaN: Boolean -> twoStepSelectingForAny, Comparable?>( @@ -147,7 +150,7 @@ public object Aggregators { // T : primitive Number? -> Double? // T : Comparable? -> T? - fun percentileCommon( + public fun percentileCommon( percentile: Double, skipNaN: Boolean, ): Aggregator @@ -155,12 +158,12 @@ public object Aggregators { this.percentile.invoke(percentile, skipNaN).cast2() // T : Comparable? -> T? - fun percentileComparables(percentile: Double): Aggregator + public fun percentileComparables(percentile: Double): Aggregator where T : Comparable? = percentileCommon(percentile, skipNaNDefault).cast2() // T : primitive Number? -> Double? - fun percentileNumbers( + public fun percentileNumbers( percentile: Double, skipNaN: Boolean, ): Aggregator From 224d65266d848802ea33d6da1434f51afa8e1a8c Mon Sep 17 00:00:00 2001 From: Alexey Zinoviev Date: Sat, 26 Apr 2025 18:12:16 +0200 Subject: [PATCH 8/8] Fix linting --- .../aggregators/AggregatorOptionSwitch.kt | 3 ++- .../aggregators/AggregatorProvider.kt | 3 ++- .../aggregation/aggregators/Aggregators.kt | 11 +++++--- .../kotlinx/dataframe/api/statistics.kt | 27 +++++++++---------- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt index f8853a64b6..75bd9865d6 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorOptionSwitch.kt @@ -32,7 +32,8 @@ public class AggregatorOptionSwitch1 Factory( getAggregator: (param1: Param1) -> AggregatorProvider, - ): Provider> = Provider { name -> AggregatorOptionSwitch1(name, getAggregator) } + ): Provider> = + Provider { name -> AggregatorOptionSwitch1(name, getAggregator) } } } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt index 158995cfea..0405f88f1e 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorProvider.kt @@ -25,4 +25,5 @@ internal operator fun Provider.getValue(obj: Any?, property: KProperty<*> * val myAggregator by MyAggregator.Factory * ``` */ -public fun interface AggregatorProvider : Provider> +public fun interface AggregatorProvider : + Provider> diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt index 1c75d42d47..d754aa3145 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregators.kt @@ -11,6 +11,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.multipleColu import org.jetbrains.kotlinx.dataframe.math.indexOfMax import org.jetbrains.kotlinx.dataframe.math.indexOfMedian import org.jetbrains.kotlinx.dataframe.math.indexOfMin +import org.jetbrains.kotlinx.dataframe.math.indexOfPercentile import org.jetbrains.kotlinx.dataframe.math.maxOrNull import org.jetbrains.kotlinx.dataframe.math.maxTypeConversion import org.jetbrains.kotlinx.dataframe.math.mean @@ -21,7 +22,6 @@ import org.jetbrains.kotlinx.dataframe.math.minOrNull import org.jetbrains.kotlinx.dataframe.math.minTypeConversion import org.jetbrains.kotlinx.dataframe.math.percentileConversion import org.jetbrains.kotlinx.dataframe.math.percentileOrNull -import org.jetbrains.kotlinx.dataframe.math.indexOfPercentile import org.jetbrains.kotlinx.dataframe.math.std import org.jetbrains.kotlinx.dataframe.math.stdTypeConversion import org.jetbrains.kotlinx.dataframe.math.sum @@ -134,7 +134,10 @@ public object Aggregators { } // T: Number? -> Double - public val std: AggregatorOptionSwitch2 by withTwoOptions { skipNaN: Boolean, ddof: Int -> + public val std: AggregatorOptionSwitch2 by withTwoOptions { + skipNaN: Boolean, + ddof: Int, + -> flattenReducingForNumbers(stdTypeConversion) { type -> std(type, skipNaN, ddof) } @@ -158,7 +161,9 @@ public object Aggregators { this.percentile.invoke(percentile, skipNaN).cast2() // T : Comparable? -> T? - public fun percentileComparables(percentile: Double): Aggregator + public fun percentileComparables( + percentile: Double, + ): Aggregator where T : Comparable? = percentileCommon(percentile, skipNaNDefault).cast2() diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt index db70a24130..726ed62a59 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/statistics.kt @@ -14,7 +14,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", )( "Alice", 15, "London", 99.5, "1.85", 50, 0.toShort(), 0.toByte(), 0L, "Bob", 20, "Paris", 140.0, "1.35", 45, 2.toShort(), 0.toByte(), 12000L, @@ -28,7 +28,6 @@ class StatisticsTests { "Isla", 22, "London", 75.1, "1.85", 43, 1.toShort(), 0.toByte(), 30000L, ) - @Test fun `sum on DataFrame`() { // scenario #0: all numerical columns @@ -39,7 +38,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) val sum01 = res0["age"] as Int @@ -70,7 +69,7 @@ class StatisticsTests { sum111 shouldBe 13 // scenario #2: sum of values per columns separately - val res3 = personsDf.sumFor( "age", "weight", "workExperienceYears", "dependentsCount", "annualIncome") + val res3 = personsDf.sumFor("age", "weight", "workExperienceYears", "dependentsCount", "annualIncome") res3.columnNames() shouldBe listOf("age", "weight", "workExperienceYears", "dependentsCount", "annualIncome") val sum31 = res3["age"] as Int @@ -96,7 +95,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) val sum01 = res0["age"][0] as Int @@ -165,7 +164,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) val mean01 = res0["age"][0] as Double @@ -236,7 +235,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) val median01 = res0["age"][0] as Double @@ -305,7 +304,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) val std01 = res0["age"][0] as Double @@ -376,7 +375,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) val min01 = res0["age"][0] as Int @@ -444,7 +443,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) // TODO: why is here weight presented? looks like inconsitency val min41 = res4["age"][0] as Int @@ -463,7 +462,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) val min51 = res5["age"][0] as Int @@ -483,7 +482,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) val max01 = res0["age"][0] as Int @@ -551,7 +550,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) // TODO: weight is here? val max41 = res4["age"][0] as Int @@ -570,7 +569,7 @@ class StatisticsTests { "yearsToRetirement", "workExperienceYears", "dependentsCount", - "annualIncome" + "annualIncome", ) val max51 = res5["age"][0] as Int