Skip to content

Add checkers that report compile time schema as info warnings to observe implicit schema generation #1051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class FirDataFrameExtensionRegistrar(
private val path: String?,
val schemasDirectory: String?,
val isTest: Boolean,
val dumpSchemas: Boolean,
) : FirExtensionRegistrar() {
@OptIn(FirExtensionApiInternals::class)
override fun ExtensionRegistrarContext.configurePlugin() {
Expand All @@ -76,7 +77,7 @@ class FirDataFrameExtensionRegistrar(
+::TokenGenerator
+::DataRowSchemaSupertype
+{ it: FirSession ->
ExpressionAnalysisAdditionalChecker(it, jsonCache(it), schemasDirectory, isTest)
ExpressionAnalysisAdditionalChecker(it, jsonCache(it), schemasDirectory, isTest, dumpSchemas)
}
}

Expand All @@ -93,7 +94,9 @@ class FirDataFrameComponentRegistrar : CompilerPluginRegistrar() {
override fun ExtensionStorage.registerExtensions(configuration: CompilerConfiguration) {
val schemasDirectory = configuration.get(SCHEMAS)
val path = configuration.get(PATH)
FirExtensionRegistrarAdapter.registerExtension(FirDataFrameExtensionRegistrar(path, schemasDirectory, isTest = false))
FirExtensionRegistrarAdapter.registerExtension(
FirDataFrameExtensionRegistrar(path, schemasDirectory, isTest = false, dumpSchemas = true)
)
IrGenerationExtension.registerExtension(IrBodyFiller(path, schemasDirectory))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,38 @@

package org.jetbrains.kotlinx.dataframe.plugin.extensions

import com.intellij.psi.PsiElement
import org.jetbrains.kotlin.KtSourceElement
import org.jetbrains.kotlin.diagnostics.AbstractSourceElementPositioningStrategy
import org.jetbrains.kotlin.diagnostics.DiagnosticFactory1DelegateProvider
import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
import org.jetbrains.kotlin.diagnostics.KtDiagnosticFactory1
import org.jetbrains.kotlin.diagnostics.Severity
import org.jetbrains.kotlin.diagnostics.SourceElementPositioningStrategies
import org.jetbrains.kotlin.diagnostics.error1
import org.jetbrains.kotlin.diagnostics.reportOn
import org.jetbrains.kotlin.diagnostics.warning1
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.DeclarationCheckers
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirPropertyChecker
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirSimpleFunctionChecker
import org.jetbrains.kotlin.fir.analysis.checkers.expression.ExpressionCheckers
import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirPropertyAccessExpressionChecker
import org.jetbrains.kotlin.fir.analysis.extensions.FirAdditionalCheckersExtension
import org.jetbrains.kotlin.fir.caches.FirCache
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.flatten
import org.jetbrains.kotlinx.dataframe.plugin.pluginDataFrameSchema
import org.jetbrains.kotlin.fir.declarations.FirProperty
import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction
import org.jetbrains.kotlin.fir.declarations.hasAnnotation
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.FirPropertyAccessExpression
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import org.jetbrains.kotlin.fir.resolve.fullyExpandedType
import org.jetbrains.kotlin.fir.types.ConeClassLikeType
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.FirTypeProjectionWithVariance
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.fir.types.isSubtypeOf
Expand All @@ -38,18 +50,34 @@ import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.psi.KtElement
import org.jetbrains.kotlinx.dataframe.plugin.impl.PluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.type
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.flatten
import org.jetbrains.kotlinx.dataframe.plugin.pluginDataFrameSchema
import org.jetbrains.kotlinx.dataframe.plugin.utils.Names
import org.jetbrains.kotlinx.dataframe.plugin.utils.isDataFrame
import org.jetbrains.kotlinx.dataframe.plugin.utils.isDataRow
import org.jetbrains.kotlinx.dataframe.plugin.utils.isGroupBy

class ExpressionAnalysisAdditionalChecker(
session: FirSession,
cache: FirCache<String, PluginDataFrameSchema, KotlinTypeFacade>,
schemasDirectory: String?,
isTest: Boolean,
dumpSchemas: Boolean
) : FirAdditionalCheckersExtension(session) {
override val expressionCheckers: ExpressionCheckers = object : ExpressionCheckers() {
override val functionCallCheckers: Set<FirFunctionCallChecker> = setOf(Checker(cache, schemasDirectory, isTest))
override val functionCallCheckers: Set<FirFunctionCallChecker> = setOfNotNull(
Checker(cache, schemasDirectory, isTest), FunctionCallSchemaReporter.takeIf { dumpSchemas }
)
override val propertyAccessExpressionCheckers: Set<FirPropertyAccessExpressionChecker> = setOfNotNull(
PropertyAccessSchemaReporter.takeIf { dumpSchemas }
)
}
override val declarationCheckers: DeclarationCheckers = object : DeclarationCheckers() {
override val propertyCheckers: Set<FirPropertyChecker> = setOfNotNull(PropertySchemaReporter.takeIf { dumpSchemas })
override val simpleFunctionCheckers: Set<FirSimpleFunctionChecker> = setOfNotNull(FunctionDeclarationSchemaReporter.takeIf { dumpSchemas })
}
}

Expand Down Expand Up @@ -132,3 +160,105 @@ private class Checker(
}
}
}

private data object PropertySchemaReporter : FirPropertyChecker(mppKind = MppCheckerKind.Common) {
val SCHEMA by info1<KtElement, String>(SourceElementPositioningStrategies.DECLARATION_NAME)

override fun check(declaration: FirProperty, context: CheckerContext, reporter: DiagnosticReporter) {
context.sessionContext {
declaration.returnTypeRef.coneType.let { type ->
reportSchema(reporter, declaration.source, SCHEMA, type, context)
}
}
}
}

private data object FunctionCallSchemaReporter : FirFunctionCallChecker(mppKind = MppCheckerKind.Common) {
val SCHEMA by info1<KtElement, String>(SourceElementPositioningStrategies.REFERENCED_NAME_BY_QUALIFIED)

override fun check(expression: FirFunctionCall, context: CheckerContext, reporter: DiagnosticReporter) {
if (expression.calleeReference.name in setOf(Name.identifier("let"), Name.identifier("run"))) return
val initializer = expression.resolvedType
context.sessionContext {
reportSchema(reporter, expression.source, SCHEMA, initializer, context)
}
}
}

private data object PropertyAccessSchemaReporter : FirPropertyAccessExpressionChecker(mppKind = MppCheckerKind.Common) {
val SCHEMA by info1<KtElement, String>(SourceElementPositioningStrategies.REFERENCED_NAME_BY_QUALIFIED)

override fun check(
expression: FirPropertyAccessExpression,
context: CheckerContext,
reporter: DiagnosticReporter
) {
val initializer = expression.resolvedType
context.sessionContext {
reportSchema(reporter, expression.source, SCHEMA, initializer, context)
}
}
}

private data object FunctionDeclarationSchemaReporter : FirSimpleFunctionChecker(mppKind = MppCheckerKind.Common) {
val SCHEMA by info1<KtElement, String>(SourceElementPositioningStrategies.DECLARATION_SIGNATURE)

override fun check(declaration: FirSimpleFunction, context: CheckerContext, reporter: DiagnosticReporter) {
val type = declaration.returnTypeRef.coneType
context.sessionContext {
reportSchema(reporter, declaration.source, SCHEMA, type, context)
}
}
}

private fun SessionContext.reportSchema(
reporter: DiagnosticReporter,
source: KtSourceElement?,
factory: KtDiagnosticFactory1<String>,
type: ConeKotlinType,
context: CheckerContext,
) {
val expandedType = type.fullyExpandedType(session)
var schema: PluginDataFrameSchema? = null
when {
expandedType.isDataFrame(session) -> {
schema = expandedType.typeArguments.getOrNull(0)?.let {
pluginDataFrameSchema(it)
}
}

expandedType.isDataRow(session) -> {
schema = expandedType.typeArguments.getOrNull(0)?.let {
pluginDataFrameSchema(it)
}
}

expandedType.isGroupBy(session) -> {
val keys = expandedType.typeArguments.getOrNull(0)
val grouped = expandedType.typeArguments.getOrNull(1)
if (keys != null && grouped != null) {
val keysSchema = pluginDataFrameSchema(keys)
val groupedSchema = pluginDataFrameSchema(grouped)
schema = PluginDataFrameSchema(
listOf(
SimpleColumnGroup("keys", keysSchema.columns()),
SimpleFrameColumn("groups", groupedSchema.columns())
)
)
}
}
}
if (schema != null && source != null) {
reporter.reportOn(source, factory, "\n" + schema.toString(), context)
}
}

fun CheckerContext.sessionContext(f: SessionContext.() -> Unit) {
SessionContext(session).f()
}

inline fun <reified P : PsiElement, A> info1(
positioningStrategy: AbstractSourceElementPositioningStrategy = SourceElementPositioningStrategies.DEFAULT
): DiagnosticFactory1DelegateProvider<A> {
return DiagnosticFactory1DelegateProvider(Severity.INFO, positioningStrategy, P::class)
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleDataColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleColumnGroup
import org.jetbrains.kotlinx.dataframe.plugin.impl.SimpleFrameColumn
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.GroupBy
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.createPluginDataFrameSchema
import kotlin.math.abs

@OptIn(FirExtensionApiInternals::class)
Expand Down Expand Up @@ -321,14 +320,14 @@ class FunctionCallTransformer(
val receiverType = explicitReceiver?.resolvedType
val returnType = call.resolvedType
val scopeFunction = if (explicitReceiver != null) findLet() else findRun()
val originalSource = call.calleeReference.source

// original call is inserted later
call.transformCalleeReference(object : FirTransformer<Nothing?>() {
override fun <E : FirElement> transformElement(element: E, data: Nothing?): E {
return if (element is FirResolvedNamedReference) {
@Suppress("UNCHECKED_CAST")
buildResolvedNamedReference {
source = call.calleeReference.source
this.name = element.name
resolvedSymbol = originalSymbol
} as E
Expand Down Expand Up @@ -430,7 +429,8 @@ class FunctionCallTransformer(
}

val newCall1 = buildFunctionCall {
source = call.source
// source = call.source makes IDE navigate to `let` declaration
source = null
this.coneTypeOrNull = returnType
if (receiverType != null) {
typeArguments += buildTypeProjectionWithVariance {
Expand All @@ -455,7 +455,7 @@ class FunctionCallTransformer(
linkedMapOf(argument to scopeFunction.valueParameterSymbols[0].fir)
)
calleeReference = buildResolvedNamedReference {
source = call.calleeReference.source
source = originalSource
this.name = scopeFunction.name
resolvedSymbol = scopeFunction
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ import kotlin.reflect.KType
import kotlin.reflect.KTypeProjection
import kotlin.reflect.KVariance

interface KotlinTypeFacade {
val session: FirSession
interface KotlinTypeFacade : SessionContext {
val resolutionPath: String? get() = null
val cache: FirCache<String, PluginDataFrameSchema, KotlinTypeFacade>
val schemasDirectory: String?
Expand Down Expand Up @@ -99,6 +98,14 @@ interface KotlinTypeFacade {
}
}

interface SessionContext {
val session: FirSession
}

fun SessionContext(session: FirSession) = object : SessionContext {
override val session: FirSession = session
}

private val List = "List".collectionsId()

private fun ConeKotlinType.isBuiltinType(classId: ClassId, isNullable: Boolean?): Boolean {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
@file:Suppress("INVISIBLE_REFERENCE", "CANNOT_OVERRIDE_INVISIBLE_MEMBER")

package org.jetbrains.kotlinx.dataframe.plugin.impl

import org.jetbrains.kotlin.fir.analysis.checkers.fullyExpandedClassId
import org.jetbrains.kotlin.fir.types.ConeKotlinType
import org.jetbrains.kotlin.fir.types.ConeNullability
import org.jetbrains.kotlin.fir.types.isNullable
import org.jetbrains.kotlin.fir.types.renderReadable
import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade
import org.jetbrains.kotlinx.dataframe.plugin.extensions.wrap
import org.jetbrains.kotlinx.dataframe.plugin.impl.api.TypeApproximation
Expand All @@ -18,6 +17,7 @@ data class PluginDataFrameSchema(
companion object {
val EMPTY = PluginDataFrameSchema(emptyList())
}

fun columns(): List<SimpleCol> {
return columns
}
Expand All @@ -32,16 +32,19 @@ fun PluginDataFrameSchema.add(name: String, type: ConeKotlinType, context: Kotli
}

private fun List<SimpleCol>.asString(indent: String = ""): String {
if (isEmpty()) return "$indent<empty compile time schema>"
return joinToString("\n") {
val col = when (it) {
is SimpleFrameColumn -> {
"${it.name}*\n" + it.columns().asString("$indent ")
"${it.name}: *\n" + it.columns().asString("$indent ")
}

is SimpleColumnGroup -> {
"${it.name}\n" + it.columns().asString("$indent ")
"${it.name}:\n" + it.columns().asString("$indent ")
}

is SimpleDataColumn -> {
"${it.name}: ${it.type}"
"${it.name}: ${it.type.type.renderReadable()}"
}
}
"$indent$col"
Expand Down Expand Up @@ -127,6 +130,7 @@ private fun KotlinTypeFacade.makeNullable(column: SimpleCol): SimpleCol {
is SimpleColumnGroup -> {
SimpleColumnGroup(column.name, column.columns().map { makeNullable(it) })
}

is SimpleFrameColumn -> column
is SimpleDataColumn -> SimpleDataColumn(column.name, column.type.changeNullability { true })
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.utils.addToStdlib.runIf
import org.jetbrains.kotlinx.dataframe.annotations.*
import org.jetbrains.kotlinx.dataframe.plugin.extensions.SessionContext
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.ColumnWithPathApproximation
import org.jetbrains.kotlinx.dataframe.plugin.impl.data.DataFrameCallableId
Expand Down Expand Up @@ -333,7 +334,7 @@ interface InterpretationErrorReporter {
}
}

fun KotlinTypeFacade.pluginDataFrameSchema(schemaTypeArg: ConeTypeProjection): PluginDataFrameSchema {
fun SessionContext.pluginDataFrameSchema(schemaTypeArg: ConeTypeProjection): PluginDataFrameSchema {
val schema = if (schemaTypeArg.isStarProjection) {
PluginDataFrameSchema.EMPTY
} else {
Expand All @@ -343,7 +344,7 @@ fun KotlinTypeFacade.pluginDataFrameSchema(schemaTypeArg: ConeTypeProjection): P
return schema
}

fun KotlinTypeFacade.pluginDataFrameSchema(coneClassLikeType: ConeClassLikeType): PluginDataFrameSchema {
fun SessionContext.pluginDataFrameSchema(coneClassLikeType: ConeClassLikeType): PluginDataFrameSchema {
val symbol = coneClassLikeType.toSymbol(session) as? FirRegularClassSymbol ?: return PluginDataFrameSchema.EMPTY
val declarationSymbols = if (symbol.isLocal && symbol.resolvedSuperTypes.firstOrNull() != session.builtinTypes.anyType.type) {
val rootSchemaSymbol = symbol.resolvedSuperTypes.first().toSymbol(session) as? FirRegularClassSymbol
Expand Down Expand Up @@ -407,7 +408,7 @@ private fun KotlinTypeFacade.columnWithPathApproximations(result: FirPropertyAcc
}
}

private fun KotlinTypeFacade.columnOf(it: FirPropertySymbol, mapping: Map<FirTypeParameterSymbol, ConeTypeProjection>): SimpleCol? {
private fun SessionContext.columnOf(it: FirPropertySymbol, mapping: Map<FirTypeParameterSymbol, ConeTypeProjection>): SimpleCol? {
val annotation = it.getAnnotationByClassId(Names.COLUMN_NAME_ANNOTATION, session)
val columnName = (annotation?.argumentMapping?.mapping?.get(Names.COLUMN_NAME_ARGUMENT) as? FirLiteralExpression)?.value as? String
val name = columnName ?: it.name.identifier
Expand Down Expand Up @@ -456,14 +457,14 @@ private fun KotlinTypeFacade.columnOf(it: FirPropertySymbol, mapping: Map<FirTyp
}
}

private fun KotlinTypeFacade.shouldBeConvertedToColumnGroup(it: FirPropertySymbol) =
private fun SessionContext.shouldBeConvertedToColumnGroup(it: FirPropertySymbol) =
isDataRow(it) ||
it.resolvedReturnType.toRegularClassSymbol(session)?.hasAnnotation(Names.DATA_SCHEMA_CLASS_ID, session) == true

private fun isDataRow(it: FirPropertySymbol) =
it.resolvedReturnType.classId == Names.DATA_ROW_CLASS_ID

private fun KotlinTypeFacade.shouldBeConvertedToFrameColumn(it: FirPropertySymbol) =
private fun SessionContext.shouldBeConvertedToFrameColumn(it: FirPropertySymbol) =
isDataFrame(it) ||
(it.resolvedReturnType.classId == Names.LIST &&
it.resolvedReturnType.typeArguments[0].type?.toRegularClassSymbol(session)?.hasAnnotation(Names.DATA_SCHEMA_CLASS_ID, session) == true)
Expand Down
Loading
Loading