Skip to content

[Compiler plugin] Support unfold #1127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion core/api/core.api
Original file line number Diff line number Diff line change
Expand Up @@ -4328,10 +4328,11 @@ public final class org/jetbrains/kotlinx/dataframe/api/TypeConversionsKt {
}

public final class org/jetbrains/kotlinx/dataframe/api/UnfoldKt {
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KCallable;ILkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static final fun unfold (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
public static synthetic fun unfold$default (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KCallable;ILkotlin/jvm/functions/Function2;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
}

public final class org/jetbrains/kotlinx/dataframe/api/UngroupKt {
Expand Down Expand Up @@ -5611,6 +5612,10 @@ public final class org/jetbrains/kotlinx/dataframe/impl/api/ToSequenceKt {
public static final fun toSequenceImpl (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/reflect/KType;)Lkotlin/sequences/Sequence;
}

public final class org/jetbrains/kotlinx/dataframe/impl/api/UnfoldKt {
public static final fun unfoldImpl (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)Lorg/jetbrains/kotlinx/dataframe/DataColumn;
}

public final class org/jetbrains/kotlinx/dataframe/impl/api/UpdateKt {
public static final fun updateImpl (Lorg/jetbrains/kotlinx/dataframe/api/Update;Lkotlin/jvm/functions/Function3;)Lorg/jetbrains/kotlinx/dataframe/DataFrame;
}
Expand Down
31 changes: 13 additions & 18 deletions core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/unfold.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,23 @@ import org.jetbrains.kotlinx.dataframe.ColumnsSelector
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.AccessApiOverload
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.annotations.Interpretable
import org.jetbrains.kotlinx.dataframe.annotations.Refine
import org.jetbrains.kotlinx.dataframe.columns.toColumnSet
import org.jetbrains.kotlinx.dataframe.impl.api.canBeUnfolded
import org.jetbrains.kotlinx.dataframe.impl.api.createDataFrameImpl
import org.jetbrains.kotlinx.dataframe.typeClass
import org.jetbrains.kotlinx.dataframe.impl.api.unfoldImpl
import kotlin.reflect.KCallable
import kotlin.reflect.KProperty

public inline fun <reified T> DataColumn<T>.unfold(): AnyCol =
when (kind()) {
ColumnKind.Group, ColumnKind.Frame -> this
/**
 * Unfolds this column by delegating to [unfoldImpl].
 *
 * The conversion is configured with `properties(roots = roots, maxDepth)`, so [roots] and
 * [maxDepth] are passed through to the dataframe-creation DSL unchanged.
 *
 * @param roots forwarded to `properties` as its `roots` argument.
 * @param maxDepth forwarded to `properties`; defaults to `0`.
 * @return the column produced by [unfoldImpl].
 */
public inline fun <reified T> DataColumn<T>.unfold(vararg roots: KCallable<*>, maxDepth: Int = 0): AnyCol {
    return unfoldImpl { properties(roots = roots, maxDepth) }
}

else -> when {
!typeClass.canBeUnfolded -> this

else -> values()
.createDataFrameImpl(typeClass) { (this as CreateDataFrameDsl<T>).properties() }
.asColumnGroup(name())
.asDataColumn()
}
}

public fun <T> DataFrame<T>.unfold(columns: ColumnsSelector<T, *>): DataFrame<T> = replace(columns).with { it.unfold() }
/**
 * Replaces every column selected by [columns] with the result of `unfoldImpl` applied to it.
 *
 * Each selected column is unfolded with `properties(roots = roots, maxDepth)` as the
 * dataframe-creation configuration.
 *
 * @param roots forwarded to `properties` as its `roots` argument.
 * @param maxDepth forwarded to `properties`; defaults to `0`.
 * @param columns selector of the columns to replace.
 * @return a dataframe in which the selected columns have been replaced.
 */
@Refine
@Interpretable("DataFrameUnfold")
public fun <T> DataFrame<T>.unfold(
    vararg roots: KCallable<*>,
    maxDepth: Int = 0,
    columns: ColumnsSelector<T, *>,
): DataFrame<T> {
    return replace(columns).with { selected ->
        selected.unfoldImpl { properties(roots = roots, maxDepth) }
    }
}

/**
 * Unfolds the columns with the given names.
 *
 * Convenience overload: the names are turned into a column set and passed to the
 * selector-based `unfold`.
 *
 * @param columns names of the columns to unfold.
 */
public fun <T> DataFrame<T>.unfold(vararg columns: String): DataFrame<T> {
    return unfold { columns.toColumnSet() }
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package org.jetbrains.kotlinx.dataframe.impl.api

import org.jetbrains.kotlinx.dataframe.AnyCol
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.CreateDataFrameDsl
import org.jetbrains.kotlinx.dataframe.api.asColumnGroup
import org.jetbrains.kotlinx.dataframe.api.asDataColumn
import org.jetbrains.kotlinx.dataframe.columns.ColumnKind
import org.jetbrains.kotlinx.dataframe.typeClass

/**
 * Converts a column of objects into a column group built from those objects.
 *
 * The column is returned unchanged when it is already a [ColumnKind.Group] or
 * [ColumnKind.Frame], or when its `typeClass` reports `canBeUnfolded == false`.
 * Otherwise the values are turned into a dataframe via `createDataFrameImpl`, configured by
 * [body], and the result is wrapped as a column group carrying this column's name.
 *
 * @param body configuration applied to the dataframe-creation DSL.
 * @return either this column or the newly created column group, as a data column.
 */
@PublishedApi
internal fun <T> DataColumn<T>.unfoldImpl(body: CreateDataFrameDsl<T>.() -> Unit): AnyCol {
    val columnKind = kind()
    val alreadyStructured = columnKind == ColumnKind.Group || columnKind == ColumnKind.Frame

    // `typeClass` is only consulted for columns that are neither groups nor frames.
    if (alreadyStructured || !typeClass.canBeUnfolded) return this

    return values()
        .createDataFrameImpl(typeClass) { (this as CreateDataFrameDsl<T>).body() }
        .asColumnGroup(name())
        .asDataColumn()
}
68 changes: 68 additions & 0 deletions core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/unfold.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package org.jetbrains.kotlinx.dataframe.api

import io.kotest.matchers.shouldBe
import io.kotest.matchers.types.shouldBeInstanceOf
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.junit.Test
import kotlin.reflect.typeOf

/**
 * Tests for `DataFrame.unfold`: flat objects, nested objects with and without `maxDepth`,
 * and columns whose values must be left untouched.
 */
class UnfoldTests {
    /** Every property of [A] becomes a nested column under the original column name. */
    @Test
    fun unfold() {
        val source = dataFrameOf("col" to listOf(A("123", 321)))

        val unfolded = source.unfold { col("col") }

        unfolded[pathOf("col", "str")][0] shouldBe "123"
        unfolded[pathOf("col", "i")][0] shouldBe 321
    }

    /**
     * With the default depth the nested `participants` list keeps its `List<Person>` type;
     * with `maxDepth = 2` its first value is a frame exposing `firstName`.
     */
    @Test
    fun `unfold deep`() {
        val firstGroup = Group(
            "1",
            listOf(
                Person("Alice", "Cooper", 15, "London"),
                Person("Bob", "Dylan", 45, "Dubai"),
            ),
        )
        val secondGroup = Group(
            "2",
            listOf(
                Person("Charlie", "Daniels", 20, "Moscow"),
                Person("Charlie", "Chaplin", 40, "Milan"),
            ),
        )
        val source = dataFrameOf("col" to listOf(firstGroup, secondGroup))
        val participantsPath = pathOf("col", "participants")

        val shallow = source.unfold { col("col") }
        shallow[participantsPath].type() shouldBe typeOf<List<Person>>()

        val deep = source.unfold(maxDepth = 2) { col("col") }
        deep[participantsPath][0].shouldBeInstanceOf<AnyFrame> { participants ->
            participants["firstName"][0] shouldBe "Alice"
        }
    }

    /** A column of plain `Int` values keeps both its type and its values. */
    @Test
    fun `keep value type`() {
        val ints = listOf(1, 2, 3, 4)

        val unfoldedColumn = dataFrameOf("int" to ints).unfold { col("int") }["int"]

        unfoldedColumn.type() shouldBe typeOf<Int>()
        unfoldedColumn.values() shouldBe ints
    }

    // Fixture: flat object with two simple properties.
    data class A(val str: String, val i: Int)

    // Fixture: element type of the nested list in [Group].
    data class Person(
        val firstName: String,
        val lastName: String,
        val age: Int,
        val city: String?,
    )

    // Fixture: object holding a list of objects, used for the depth-dependent test.
    data class Group(val id: String, val participants: List<Person>)
}
Original file line number Diff line number Diff line change
Expand Up @@ -197,61 +197,6 @@ internal fun KotlinTypeFacade.toDataFrame(
arg: ConeTypeProjection,
traverseConfiguration: TraverseConfiguration,
): PluginDataFrameSchema {

val anyType = session.builtinTypes.nullableAnyType.type

fun ConeKotlinType.isValueType() =
this.isArrayTypeOrNullableArrayType ||
this.classId == StandardClassIds.Unit ||
this.classId == StandardClassIds.Any ||
this.classId == StandardClassIds.Map ||
this.classId == StandardClassIds.MutableMap ||
this.classId == StandardClassIds.String ||
this.classId in StandardClassIds.primitiveTypes ||
this.classId in StandardClassIds.unsignedTypes ||
classId in setOf(
Names.DURATION_CLASS_ID,
Names.LOCAL_DATE_CLASS_ID,
Names.LOCAL_DATE_TIME_CLASS_ID,
Names.INSTANT_CLASS_ID,
Names.DATE_TIME_PERIOD_CLASS_ID,
Names.DATE_TIME_UNIT_CLASS_ID,
Names.TIME_ZONE_CLASS_ID
) ||
this.isSubtypeOf(
StandardClassIds.Number.constructClassLikeType(emptyArray(), isNullable = true),
session
) ||
this.toRegularClassSymbol(session)?.isEnumClass ?: false ||
this.isSubtypeOf(
Names.TEMPORAL_ACCESSOR_CLASS_ID.constructClassLikeType(emptyArray(), isNullable = true), session
) ||
this.isSubtypeOf(
Names.TEMPORAL_AMOUNT_CLASS_ID.constructClassLikeType(emptyArray(), isNullable = true), session
)


fun FirNamedFunctionSymbol.isGetterLike(): Boolean {
val functionName = this.name.asString()
return (functionName.startsWith("get") || functionName.startsWith("is")) &&
this.valueParameterSymbols.isEmpty() &&
this.typeParameterSymbols.isEmpty()
}

fun ConeKotlinType.hasProperties(): Boolean {
val symbol = this.toRegularClassSymbol(session) as? FirClassSymbol<*> ?: return false
val scope = symbol.unsubstitutedScope(
session,
ScopeSession(),
withForcedTypeCalculator = false,
memberRequiredPhase = null
)

return scope.collectAllProperties().any { it.visibility == Visibilities.Public } ||
scope.collectAllFunctions().any { it.visibility == Visibilities.Public && it.isGetterLike() }
}


val excludes =
traverseConfiguration.excludeProperties.mapNotNullTo(mutableSetOf()) { it.calleeReference.toResolvedPropertySymbol() }
val excludedClasses = traverseConfiguration.excludeClasses.mapTo(mutableSetOf()) { it.argument.resolvedType }
Expand Down Expand Up @@ -322,7 +267,7 @@ internal fun KotlinTypeFacade.toDataFrame(

val keepSubtree =
depth >= maxDepth && !fieldKind.shouldBeConvertedToColumnGroup && !fieldKind.shouldBeConvertedToFrameColumn
if (keepSubtree || returnType.isValueType() || returnType.classId in preserveClasses || it in preserveProperties) {
if (keepSubtree || returnType.isValueType(session) || returnType.classId in preserveClasses || it in preserveProperties) {
SimpleDataColumn(
name,
TypeApproximation(
Expand All @@ -349,7 +294,7 @@ internal fun KotlinTypeFacade.toDataFrame(
ConeStarProjection -> session.builtinTypes.nullableAnyType.type
else -> session.builtinTypes.nullableAnyType.type
}
if (type.isValueType()) {
if (type.isValueType(session)) {
val columnType = List.constructClassLikeType(arrayOf(type), returnType.isNullable)
.withNullability(ConeNullability.create(makeNullable), session.typeContext)
.wrap()
Expand All @@ -364,7 +309,7 @@ internal fun KotlinTypeFacade.toDataFrame(
}

arg.type?.let { type ->
if (type.isValueType() || !type.hasProperties()) {
if (!type.canBeUnfolded(session)) {
return PluginDataFrameSchema(listOf(simpleColumnOf("value", type)))
}
}
Expand All @@ -383,6 +328,60 @@ internal fun KotlinTypeFacade.toDataFrame(
}
}

/**
 * Tells whether this type can be unfolded: it must not be a value type (see `isValueType`)
 * and it must have properties (see `hasProperties`).
 */
fun ConeKotlinType.canBeUnfolded(session: FirSession): Boolean {
    // Value types are rejected first; the member lookup only runs for the remaining types.
    if (isValueType(session)) return false
    return hasProperties(session)
}

/**
 * Tells whether this type is treated as a value type.
 *
 * A type qualifies when it is an array, one of the listed standard or `Names` class ids,
 * an enum class, or a subtype of `Number`, `Names.TEMPORAL_ACCESSOR_CLASS_ID` or
 * `Names.TEMPORAL_AMOUNT_CLASS_ID`. Checks run top to bottom and stop at the first match.
 */
private fun ConeKotlinType.isValueType(session: FirSession): Boolean =
    when {
        this.isArrayTypeOrNullableArrayType -> true

        // Standard classes matched by exact class id.
        this.classId == StandardClassIds.Unit -> true
        this.classId == StandardClassIds.Any -> true
        this.classId == StandardClassIds.Map -> true
        this.classId == StandardClassIds.MutableMap -> true
        this.classId == StandardClassIds.String -> true
        this.classId in StandardClassIds.primitiveTypes -> true
        this.classId in StandardClassIds.unsignedTypes -> true

        // Class ids declared in `Names`, matched by exact class id.
        classId in setOf(
            Names.DURATION_CLASS_ID,
            Names.LOCAL_DATE_CLASS_ID,
            Names.LOCAL_DATE_TIME_CLASS_ID,
            Names.INSTANT_CLASS_ID,
            Names.DATE_TIME_PERIOD_CLASS_ID,
            Names.DATE_TIME_UNIT_CLASS_ID,
            Names.TIME_ZONE_CLASS_ID
        ) -> true

        this.isSubtypeOf(
            StandardClassIds.Number.constructClassLikeType(emptyArray(), isNullable = true),
            session
        ) -> true

        this.toRegularClassSymbol(session)?.isEnumClass == true -> true

        this.isSubtypeOf(
            Names.TEMPORAL_ACCESSOR_CLASS_ID.constructClassLikeType(emptyArray(), isNullable = true),
            session
        ) -> true

        else -> this.isSubtypeOf(
            Names.TEMPORAL_AMOUNT_CLASS_ID.constructClassLikeType(emptyArray(), isNullable = true),
            session
        )
    }


/**
 * Tells whether the class behind this type has at least one public property or one public
 * getter-like function (see `isGetterLike`).
 *
 * Returns `false` when the type does not resolve to a class symbol.
 */
private fun ConeKotlinType.hasProperties(session: FirSession): Boolean {
    val classSymbol = this.toRegularClassSymbol(session) as? FirClassSymbol<*>
        ?: return false

    val memberScope = classSymbol.unsubstitutedScope(
        session,
        ScopeSession(),
        withForcedTypeCalculator = false,
        memberRequiredPhase = null
    )

    val hasPublicProperty = memberScope.collectAllProperties().any { property ->
        property.visibility == Visibilities.Public
    }
    // Functions are only collected when no public property was found.
    if (hasPublicProperty) return true

    return memberScope.collectAllFunctions().any { function ->
        function.visibility == Visibilities.Public && function.isGetterLike()
    }
}

/**
 * Tells whether this function looks like a getter: its name starts with `get` or `is`,
 * and it declares neither value parameters nor type parameters.
 */
private fun FirNamedFunctionSymbol.isGetterLike(): Boolean {
    if (valueParameterSymbols.isNotEmpty() || typeParameterSymbols.isNotEmpty()) return false

    val simpleName = name.asString()
    return simpleName.startsWith("get") || simpleName.startsWith("is")
}

// org.jetbrains.kotlinx.dataframe.codeGen.getFieldKind
private fun ConeKotlinType.getFieldKind(session: FirSession) = FieldKind.of(
this,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.jetbrains.kotlinx.dataframe.plugin.impl.api

import org.jetbrains.kotlinx.dataframe.api.replace
import org.jetbrains.kotlinx.dataframe.api.with
import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter
import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.Present
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.asDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.asDataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.asSimpleColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.dataFrame
import org.jetbrains.kotlinx.dataframe.plugin.impl.ignore
import org.jetbrains.kotlinx.dataframe.plugin.impl.toPluginDataFrameSchema

/**
 * Schema interpreter registered under the name `"DataFrameUnfold"`.
 *
 * Each selected column that is a [SimpleDataColumn] whose type `canBeUnfolded` is replaced
 * by a [SimpleColumnGroup] holding the columns computed by `toDataFrame` for that type.
 * All other selected columns are kept as they are.
 */
class DataFrameUnfold : AbstractSchemaModificationInterpreter() {
    val Arguments.receiver: PluginDataFrameSchema by dataFrame()

    // NOTE(review): the library-side `unfold` declares this vararg parameter as `roots`;
    // confirm that `properties` is the intended argument name here.
    val Arguments.properties by ignore()
    val Arguments.maxDepth: Int by arg(defaultValue = Present(0))
    val Arguments.columns: ColumnsResolver by arg()

    override fun Arguments.interpret(): PluginDataFrameSchema =
        receiver.asDataFrame()
            .replace { columns }
            .with { original ->
                val valueColumn = original.asSimpleColumn() as? SimpleDataColumn
                if (valueColumn == null || !valueColumn.type.type.canBeUnfolded(session)) {
                    // Not a plain data column, or a type that cannot be unfolded: keep it.
                    original
                } else {
                    val nestedColumns = toDataFrame(maxDepth, valueColumn.type.type, TraverseConfiguration()).columns()
                    SimpleColumnGroup(original.name(), nestedColumns).asDataColumn()
                }
            }
            .toPluginDataFrameSchema()
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.api.ConcatWithKeys
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameBuilderInvoke0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameOf3
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameUnfold
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.DataFrameXs
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Drop0
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.Drop1
Expand Down Expand Up @@ -472,6 +473,7 @@ internal inline fun <reified T> String.load(): T {
"DataFrameXs" -> DataFrameXs()
"GroupByXs" -> GroupByXs()
"ConcatWithKeys" -> ConcatWithKeys()
"DataFrameUnfold" -> DataFrameUnfold()
else -> error("$this")
} as T
}
Loading
Loading