Skip to content

Commit 4c17859

Browse files
committed
added conversion for @sparkify'ed classes to scala.Product with tests.
1 parent df021c0 commit 4c17859

File tree

11 files changed

+1004
-6
lines changed

11 files changed

+1004
-6
lines changed

build.gradle.kts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ subprojects {
152152
buildConfigField("defaultSparkifyFqName", defaultSparkifyFqName)
153153
buildConfigField("defaultColumnNameFqName", defaultColumnNameFqName)
154154
buildConfigField("projectRoot", projectRoot)
155+
156+
buildConfigField("scalaVersion", Versions.scala)
157+
buildConfigField("sparkVersion", Versions.spark)
155158
}
156159
}
157160
}

compiler-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/SparkifyCompilerPluginRegistrar.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@ open class SparkifyCompilerPluginRegistrar: CompilerPluginRegistrar() {
2323
val columnNameAnnotationFqNames = configuration.get(KEY_COLUMN_NAME_ANNOTATION_FQ_NAMES)
2424
?: listOf(Artifacts.defaultColumnNameFqName)
2525

26+
val productFqNames = // TODO: get from configuration
27+
listOf("scala.Product")
28+
2629
IrGenerationExtension.registerExtension(
2730
SparkifyIrGenerationExtension(
2831
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
2932
columnNameAnnotationFqNames = columnNameAnnotationFqNames,
33+
productFqNames = productFqNames,
3034
)
3135
)
3236
}

compiler-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/ir/DataClassPropertyAnnotationGenerator.kt

Lines changed: 258 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,56 @@
11
package org.jetbrains.kotlinx.spark.api.compilerPlugin.ir
22

33
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
4+
import org.jetbrains.kotlin.backend.common.ir.addDispatchReceiver
5+
import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
6+
import org.jetbrains.kotlin.backend.common.lower.irThrow
7+
import org.jetbrains.kotlin.descriptors.Modality
48
import org.jetbrains.kotlin.ir.IrElement
59
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
610
import org.jetbrains.kotlin.ir.backend.js.utils.valueArguments
11+
import org.jetbrains.kotlin.ir.builders.declarations.addFunction
12+
import org.jetbrains.kotlin.ir.builders.declarations.addValueParameter
13+
import org.jetbrains.kotlin.ir.builders.irBlockBody
14+
import org.jetbrains.kotlin.ir.builders.irBranch
15+
import org.jetbrains.kotlin.ir.builders.irCall
16+
import org.jetbrains.kotlin.ir.builders.irElseBranch
17+
import org.jetbrains.kotlin.ir.builders.irEquals
18+
import org.jetbrains.kotlin.ir.builders.irGet
19+
import org.jetbrains.kotlin.ir.builders.irIs
20+
import org.jetbrains.kotlin.ir.builders.irReturn
21+
import org.jetbrains.kotlin.ir.builders.irWhen
722
import org.jetbrains.kotlin.ir.declarations.IrClass
8-
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
9-
import org.jetbrains.kotlin.ir.declarations.IrFile
10-
import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
1123
import org.jetbrains.kotlin.ir.declarations.IrProperty
12-
import org.jetbrains.kotlin.ir.expressions.IrBlockBody
1324
import org.jetbrains.kotlin.ir.expressions.IrConst
25+
import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
1426
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
1527
import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl
28+
import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
29+
import org.jetbrains.kotlin.ir.types.classFqName
30+
import org.jetbrains.kotlin.ir.types.classOrNull
1631
import org.jetbrains.kotlin.ir.types.defaultType
32+
import org.jetbrains.kotlin.ir.types.superTypes
1733
import org.jetbrains.kotlin.ir.util.constructors
34+
import org.jetbrains.kotlin.ir.util.defaultType
35+
import org.jetbrains.kotlin.ir.util.functions
1836
import org.jetbrains.kotlin.ir.util.hasAnnotation
1937
import org.jetbrains.kotlin.ir.util.isAnnotationWithEqualFqName
2038
import org.jetbrains.kotlin.ir.util.parentAsClass
2139
import org.jetbrains.kotlin.ir.util.primaryConstructor
40+
import org.jetbrains.kotlin.ir.util.properties
41+
import org.jetbrains.kotlin.ir.util.toIrConst
2242
import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
2343
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
2444
import org.jetbrains.kotlin.name.ClassId
2545
import org.jetbrains.kotlin.name.FqName
46+
import org.jetbrains.kotlin.name.Name
47+
import org.jetbrains.kotlin.name.SpecialNames
2648

2749
class DataClassPropertyAnnotationGenerator(
2850
private val pluginContext: IrPluginContext,
2951
private val sparkifyAnnotationFqNames: List<String>,
3052
private val columnNameAnnotationFqNames: List<String>,
53+
private val productFqNames: List<String>,
3154
) : IrElementVisitorVoid {
3255

3356
init {
@@ -51,11 +74,35 @@ class DataClassPropertyAnnotationGenerator(
5174
}
5275
}
5376

77+
/**
78+
* Converts
79+
* ```kt
80+
* @Sparkify
81+
* data class User(
82+
* val name: String = "John Doe",
83+
* @get:JvmName("ignored") val age: Int = 25,
84+
* @ColumnName("a") val test: Double = 1.0,
85+
* @get:ColumnName("b") val test2: Double = 2.0,
86+
* )
87+
* ```
88+
* to
89+
* ```kt
90+
* @Sparkify
91+
* data class User(
92+
* @get:JvmName("name") val name: String = "John Doe",
93+
* @get:JvmName("age") val age: Int = 25,
94+
* @get:JvmName("a") @ColumnName("a") val test: Double = 1.0,
95+
* @get:JvmName("b") @get:ColumnName("b") val test2: Double = 2.0,
96+
* )
97+
* ```
98+
*/
5499
override fun visitProperty(declaration: IrProperty) {
55100
val origin = declaration.parent as? IrClass ?: return super.visitProperty(declaration)
56101
if (sparkifyAnnotationFqNames.none { origin.hasAnnotation(FqName(it)) })
57102
return super.visitProperty(declaration)
58103

104+
if (!origin.isData) return super.visitProperty(declaration)
105+
59106
// must be in primary constructor
60107
val constructorParams = declaration.parentAsClass.primaryConstructor?.valueParameters
61108
?: return super.visitProperty(declaration)
@@ -96,7 +143,7 @@ class DataClassPropertyAnnotationGenerator(
96143
.filterNot { it.isAnnotationWithEqualFqName(jvmNameFqName) }
97144

98145
// create a new JvmName annotation with newName
99-
val jvmNameClassId = ClassId(jvmNameFqName.parent(), jvmNameFqName.shortName())
146+
val jvmNameClassId = jvmNameFqName.toClassId()
100147
val jvmName = pluginContext.referenceClass(jvmNameClassId)!!
101148
val jvmNameConstructor = jvmName
102149
.constructors
@@ -118,4 +165,210 @@ class DataClassPropertyAnnotationGenerator(
118165
getter.annotations += jvmNameAnnotationCall
119166
println("Added @get:JvmName(\"$newName\") annotation to property ${origin.name}.${declaration.name}")
120167
}
168+
169+
private fun FqName.toClassId(): ClassId = ClassId(packageFqName = parent(), topLevelName = shortName())
170+
171+
/**
172+
* Converts
173+
* ```kt
174+
* @Sparkify
175+
* data class User(
176+
* val name: String = "John Doe",
177+
* val age: Int = 25,
178+
* @ColumnName("a") val test: Double = 1.0,
179+
* @get:ColumnName("b") val test2: Double = 2.0,
180+
* )
181+
* ```
182+
* to
183+
* ```kt
184+
* @Sparkify
185+
* data class User(
186+
* val name: String = "John Doe",
187+
* val age: Int = 25,
188+
* @ColumnName("a") val test: Double = 1.0,
189+
* @get:ColumnName("b") val test2: Double = 2.0,
190+
* ): scala.Product {
191+
* override fun canEqual(that: Any?): Boolean = that is User
192+
* override fun productElement(n: Int): Any = when (n) {
193+
* 0 -> name
194+
* 1 -> age
195+
* 2 -> test
196+
* else -> throw IndexOutOfBoundsException(n.toString())
197+
* }
198+
* override fun productArity(): Int = 4
199+
* }
200+
* ```
201+
*/
202+
@OptIn(UnsafeDuringIrConstructionAPI::class)
203+
override fun visitClass(declaration: IrClass) {
204+
if (sparkifyAnnotationFqNames.none { declaration.hasAnnotation(FqName(it)) })
205+
return super.visitClass(declaration)
206+
207+
if (!declaration.isData) return super.visitClass(declaration)
208+
209+
// add superclass
210+
val scalaProductClass = productFqNames.firstNotNullOfOrNull {
211+
val classId = ClassId.topLevel(FqName(it))
212+
// ClassId(
213+
// packageFqName = FqName("scala"),
214+
// topLevelName = Name.identifier("Product"),
215+
// )
216+
pluginContext.referenceClass(classId)
217+
}!!
218+
219+
declaration.superTypes += scalaProductClass.defaultType
220+
221+
// finding the constructor params
222+
val constructorParams = declaration.primaryConstructor?.valueParameters
223+
?: return super.visitClass(declaration)
224+
225+
// finding properties
226+
val props = declaration.properties
227+
228+
// getting the properties that are in the constructor
229+
val properties = constructorParams.mapNotNull { param ->
230+
props.firstOrNull { it.name == param.name }
231+
}
232+
233+
// finding supertype Equals
234+
val superEqualsInterface = scalaProductClass.superTypes()
235+
.first { it.classFqName?.shortName()?.asString()?.contains("Equals") == true }
236+
.classOrNull ?: return super.visitClass(declaration)
237+
238+
// add canEqual
239+
val superCanEqualFunction = superEqualsInterface.functions.first {
240+
it.owner.name.asString() == "canEqual" &&
241+
it.owner.valueParameters.size == 1 &&
242+
it.owner.valueParameters.first().type == pluginContext.irBuiltIns.anyNType
243+
}
244+
245+
val canEqualFunction = declaration.addFunction(
246+
name = "canEqual",
247+
returnType = pluginContext.irBuiltIns.booleanType,
248+
modality = Modality.OPEN,
249+
)
250+
with(canEqualFunction) {
251+
overriddenSymbols = listOf(superCanEqualFunction)
252+
parent = declaration
253+
254+
// add implicit $this parameter
255+
addDispatchReceiver {
256+
name = SpecialNames.THIS
257+
type = declaration.defaultType
258+
}
259+
260+
// add that parameter
261+
val that = addValueParameter(
262+
name = Name.identifier("that"),
263+
type = pluginContext.irBuiltIns.anyNType,
264+
)
265+
266+
// add body
267+
body = pluginContext.irBuiltIns.createIrBuilder(symbol).irBlockBody {
268+
val call = irIs(argument = irGet(that), type = declaration.defaultType)
269+
+irReturn(call)
270+
}
271+
}
272+
273+
// add productArity
274+
val superProductArityFunction = scalaProductClass.functions.first {
275+
it.owner.name.asString() == "productArity" &&
276+
it.owner.valueParameters.isEmpty()
277+
}
278+
279+
val productArityFunction = declaration.addFunction(
280+
name = "productArity",
281+
returnType = pluginContext.irBuiltIns.intType,
282+
modality = Modality.OPEN,
283+
)
284+
with(productArityFunction) {
285+
overriddenSymbols = listOf(superProductArityFunction)
286+
parent = declaration
287+
288+
// add implicit $this parameter
289+
addDispatchReceiver {
290+
name = SpecialNames.THIS
291+
type = declaration.defaultType
292+
}
293+
294+
// add body
295+
body = pluginContext.irBuiltIns.createIrBuilder(symbol).irBlockBody {
296+
val const = properties.size.toIrConst(pluginContext.irBuiltIns.intType)
297+
+irReturn(const)
298+
}
299+
}
300+
301+
// add productElement
302+
val superProductElementFunction = scalaProductClass.functions.first {
303+
it.owner.name.asString() == "productElement" &&
304+
it.owner.valueParameters.size == 1 &&
305+
it.owner.valueParameters.first().type == pluginContext.irBuiltIns.intType
306+
}
307+
308+
val productElementFunction = declaration.addFunction(
309+
name = "productElement",
310+
returnType = pluginContext.irBuiltIns.anyNType,
311+
modality = Modality.OPEN,
312+
)
313+
with(productElementFunction) {
314+
overriddenSymbols = listOf(superProductElementFunction)
315+
parent = declaration
316+
317+
// add implicit $this parameter
318+
val `this` = addDispatchReceiver {
319+
name = SpecialNames.THIS
320+
type = declaration.defaultType
321+
}
322+
323+
// add n parameter
324+
val n = addValueParameter(
325+
name = Name.identifier("n"),
326+
type = pluginContext.irBuiltIns.intType,
327+
)
328+
329+
// add body
330+
body = pluginContext.irBuiltIns.createIrBuilder(symbol).irBlockBody {
331+
val whenBranches = buildList {
332+
for ((i, prop) in properties.withIndex()) {
333+
val condition = irEquals(
334+
arg1 = irGet(n),
335+
arg2 = i.toIrConst(pluginContext.irBuiltIns.intType),
336+
)
337+
val call = irCall(prop.getter!!)
338+
with(call) {
339+
origin = IrStatementOrigin.GET_PROPERTY
340+
dispatchReceiver = irGet(`this`)
341+
}
342+
343+
val branch = irBranch(
344+
condition = condition,
345+
result = call
346+
)
347+
add(branch)
348+
}
349+
350+
val ioobClass = pluginContext.referenceClass(
351+
FqName("java.lang.IndexOutOfBoundsException").toClassId()
352+
)!!
353+
val ioobConstructor = ioobClass.constructors.first { it.owner.valueParameters.isEmpty() }
354+
val throwCall = irThrow(
355+
IrConstructorCallImpl.fromSymbolOwner(
356+
ioobClass.defaultType,
357+
ioobConstructor
358+
)
359+
)
360+
val elseBranch = irElseBranch(throwCall)
361+
add(elseBranch)
362+
}
363+
val whenBlock = irWhen(pluginContext.irBuiltIns.anyNType, whenBranches)
364+
with(whenBlock) {
365+
origin = IrStatementOrigin.IF
366+
}
367+
+irReturn(whenBlock)
368+
}
369+
}
370+
371+
// pass down to the properties
372+
declaration.acceptChildrenVoid(this)
373+
}
121374
}

compiler-plugin/src/main/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/ir/SparkifyIrGenerationExtension.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@ import org.jetbrains.kotlinx.spark.api.compilerPlugin.ir.DataClassPropertyAnnota
99
class SparkifyIrGenerationExtension(
1010
private val sparkifyAnnotationFqNames: List<String>,
1111
private val columnNameAnnotationFqNames: List<String>,
12+
private val productFqNames: List<String>,
1213
) : IrGenerationExtension {
1314
override fun generate(moduleFragment: IrModuleFragment, pluginContext: IrPluginContext) {
1415
val visitors = listOf(
1516
DataClassPropertyAnnotationGenerator(
1617
pluginContext = pluginContext,
1718
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
1819
columnNameAnnotationFqNames = columnNameAnnotationFqNames,
20+
productFqNames = productFqNames,
1921
),
2022
)
2123
for (visitor in visitors) {

compiler-plugin/src/test-gen/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/runners/BoxTestGenerated.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@ public void testDataClassInFunctionTest() {
2727
runTest("/mnt/data/Projects/kotlin-spark-api/compiler-plugin/src/test/resources/testData/box/dataClassInFunctionTest.kt");
2828
}
2929

30+
@Test
31+
@TestMetadata("dataClassIsProductTest.kt")
32+
public void testDataClassIsProductTest() {
33+
runTest("/mnt/data/Projects/kotlin-spark-api/compiler-plugin/src/test/resources/testData/box/dataClassIsProductTest.kt");
34+
}
35+
3036
@Test
3137
@TestMetadata("dataClassTest.kt")
3238
public void testDataClassTest() {

compiler-plugin/src/test/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/runners/BaseTestRunner.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,6 @@ fun TestConfigurationBuilder.commonFirWithPluginFrontendConfiguration() {
3535
}
3636

3737
useConfigurators(
38-
::ExtensionRegistrarConfigurator
38+
::ExtensionRegistrarConfigurator,
3939
)
4040
}

compiler-plugin/src/test/kotlin/org/jetbrains/kotlinx/spark/api/compilerPlugin/services/ExtensionRegistrarConfigurator.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ class ExtensionRegistrarConfigurator(testServices: TestServices) : EnvironmentCo
1515
) {
1616
val sparkifyAnnotationFqNames = listOf("foo.bar.Sparkify")
1717
val columnNameAnnotationFqNames = listOf("foo.bar.ColumnName")
18+
val productFqNames = listOf("foo.bar.Product")
1819
IrGenerationExtension.registerExtension(
1920
SparkifyIrGenerationExtension(
2021
sparkifyAnnotationFqNames = sparkifyAnnotationFqNames,
2122
columnNameAnnotationFqNames = columnNameAnnotationFqNames,
23+
productFqNames = productFqNames,
2224
)
2325
)
2426
}

0 commit comments

Comments
 (0)