diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateDsl.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateDsl.kt index 8bb789d91d..6bd7c2052b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateDsl.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/aggregation/AggregateDsl.kt @@ -1,6 +1,7 @@ package org.jetbrains.kotlinx.dataframe.aggregation import org.jetbrains.kotlinx.dataframe.DataFrame +import org.jetbrains.kotlinx.dataframe.annotations.HasSchema import org.jetbrains.kotlinx.dataframe.annotations.Interpretable import org.jetbrains.kotlinx.dataframe.api.ColumnSelectionDsl import org.jetbrains.kotlinx.dataframe.api.pathOf @@ -11,6 +12,7 @@ import org.jetbrains.kotlinx.dataframe.impl.columnName import kotlin.reflect.KProperty import kotlin.reflect.typeOf +@HasSchema(schemaArg = 0) public abstract class AggregateDsl : DataFrame, ColumnSelectionDsl { diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/IrBodyFiller.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/IrBodyFiller.kt index 66ff2c9a57..4e81d8764d 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/IrBodyFiller.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/extensions/IrBodyFiller.kt @@ -30,6 +30,7 @@ import org.jetbrains.kotlin.ir.expressions.IrConst import org.jetbrains.kotlin.ir.expressions.IrErrorCallExpression import org.jetbrains.kotlin.ir.expressions.IrExpression import org.jetbrains.kotlin.ir.expressions.IrTypeOperator +import org.jetbrains.kotlin.ir.expressions.IrTypeOperatorCall import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl @@ -41,6 +42,7 @@ import org.jetbrains.kotlin.ir.expressions.impl.IrTypeOperatorCallImpl import org.jetbrains.kotlin.ir.symbols.IrValueSymbol import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI import org.jetbrains.kotlin.ir.types.IrSimpleType +import org.jetbrains.kotlin.ir.types.IrType import org.jetbrains.kotlin.ir.types.classFqName import org.jetbrains.kotlin.ir.types.classOrFail import org.jetbrains.kotlin.ir.types.classifierOrNull @@ -235,16 +237,32 @@ private class DataFrameFileLowering(val context: IrPluginContext) : FileLowering return true } - @OptIn(UnsafeDuringIrConstructionAPI::class) + // org.jetbrains.kotlin.fir.backend.generators.CallAndReferenceGenerator#applyReceivers + override fun visitTypeOperator(expression: IrTypeOperatorCall): IrExpression { + if (isScope(expression.typeOperand)) { + return expression.replaceWithConstructorCall() + } + return super.visitTypeOperator(expression) + } + override fun visitErrorCallExpression(expression: IrErrorCallExpression): IrExpression { - val origin = (expression.type.classifierOrNull?.owner as? IrClass)?.origin ?: return expression - val fromPlugin = origin is IrDeclarationOrigin.GeneratedByPlugin && origin.pluginKey is DataFramePlugin - val scopeReference = expression.type.classFqName?.shortName()?.asString()?.startsWith("Scope") ?: false - if (!(fromPlugin || scopeReference)) { + if (!isScope(expression.type)) { return expression } - val constructor = expression.type.getClass()!!.constructors.toList().single() - val type = expression.type + return expression.replaceWithConstructorCall() + } + + @OptIn(UnsafeDuringIrConstructionAPI::class) + private fun isScope(type: IrType): Boolean { + val origin = (type.classifierOrNull?.owner as? IrClass)?.origin ?: return false + val fromPlugin = origin is IrDeclarationOrigin.GeneratedByPlugin && origin.pluginKey is DataFramePlugin + val scopeReference = type.classFqName?.shortName()?.asString()?.startsWith("Scope") ?: false + return fromPlugin || scopeReference + } + + @OptIn(UnsafeDuringIrConstructionAPI::class) + private fun IrExpression.replaceWithConstructorCall(): IrConstructorCallImpl { + val constructor = type.getClass()!!.constructors.toList().single() return IrConstructorCallImpl(-1, -1, type, constructor.symbol, 0, 0, 0) } } diff --git a/plugins/kotlin-dataframe/testData/box/wrongReceiver.kt b/plugins/kotlin-dataframe/testData/box/wrongReceiver.kt new file mode 100644 index 0000000000..b06d4ff0c9 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/wrongReceiver.kt @@ -0,0 +1,14 @@ +import org.jetbrains.kotlinx.dataframe.annotations.DataSchema +import org.jetbrains.kotlinx.dataframe.api.* + + +@DataSchema +data class Record(val a: String, val b: Int) + +fun box(): String { + val df = List(10) { Record(it.toString(), it) }.let { dataFrameOf(*it.toTypedArray()) } + val aggregate = df.pivot { b }.aggregate { + this.add("c") { 123 }.c + } + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index ffcdd2a659..6d6f9bcb9c 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -442,6 +442,12 @@ public void testUpdate() { runTest("testData/box/update.kt"); } + @Test + @TestMetadata("wrongReceiver.kt") + public void testWrongReceiver() { + runTest("testData/box/wrongReceiver.kt"); + } + @Nested @TestMetadata("testData/box/colKinds") @TestDataPath("$PROJECT_ROOT")