11package org.jetbrains.kotlinx.spark.api.compilerPlugin.ir
22
33import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
4+ import org.jetbrains.kotlin.backend.common.ir.addDispatchReceiver
5+ import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
6+ import org.jetbrains.kotlin.backend.common.lower.irThrow
7+ import org.jetbrains.kotlin.descriptors.Modality
48import org.jetbrains.kotlin.ir.IrElement
59import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
610import org.jetbrains.kotlin.ir.backend.js.utils.valueArguments
11+ import org.jetbrains.kotlin.ir.builders.declarations.addFunction
12+ import org.jetbrains.kotlin.ir.builders.declarations.addValueParameter
13+ import org.jetbrains.kotlin.ir.builders.irBlockBody
14+ import org.jetbrains.kotlin.ir.builders.irBranch
15+ import org.jetbrains.kotlin.ir.builders.irCall
16+ import org.jetbrains.kotlin.ir.builders.irElseBranch
17+ import org.jetbrains.kotlin.ir.builders.irEquals
18+ import org.jetbrains.kotlin.ir.builders.irGet
19+ import org.jetbrains.kotlin.ir.builders.irIs
20+ import org.jetbrains.kotlin.ir.builders.irReturn
21+ import org.jetbrains.kotlin.ir.builders.irWhen
722import org.jetbrains.kotlin.ir.declarations.IrClass
8- import org.jetbrains.kotlin.ir.declarations.IrDeclaration
9- import org.jetbrains.kotlin.ir.declarations.IrFile
10- import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
1123import org.jetbrains.kotlin.ir.declarations.IrProperty
12- import org.jetbrains.kotlin.ir.expressions.IrBlockBody
1324import org.jetbrains.kotlin.ir.expressions.IrConst
25+ import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
1426import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
1527import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl
28+ import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
29+ import org.jetbrains.kotlin.ir.types.classFqName
30+ import org.jetbrains.kotlin.ir.types.classOrNull
1631import org.jetbrains.kotlin.ir.types.defaultType
32+ import org.jetbrains.kotlin.ir.types.superTypes
1733import org.jetbrains.kotlin.ir.util.constructors
34+ import org.jetbrains.kotlin.ir.util.defaultType
35+ import org.jetbrains.kotlin.ir.util.functions
1836import org.jetbrains.kotlin.ir.util.hasAnnotation
1937import org.jetbrains.kotlin.ir.util.isAnnotationWithEqualFqName
2038import org.jetbrains.kotlin.ir.util.parentAsClass
2139import org.jetbrains.kotlin.ir.util.primaryConstructor
40+ import org.jetbrains.kotlin.ir.util.properties
41+ import org.jetbrains.kotlin.ir.util.toIrConst
2242import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
2343import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
2444import org.jetbrains.kotlin.name.ClassId
2545import org.jetbrains.kotlin.name.FqName
46+ import org.jetbrains.kotlin.name.Name
47+ import org.jetbrains.kotlin.name.SpecialNames
2648
2749class DataClassPropertyAnnotationGenerator (
2850 private val pluginContext : IrPluginContext ,
2951 private val sparkifyAnnotationFqNames : List <String >,
3052 private val columnNameAnnotationFqNames : List <String >,
53+ private val productFqNames : List <String >,
3154) : IrElementVisitorVoid {
3255
3356 init {
@@ -51,11 +74,35 @@ class DataClassPropertyAnnotationGenerator(
5174 }
5275 }
5376
77+ /* *
78+ * Converts
79+ * ```kt
80+ * @Sparkify
81+ * data class User(
82+ * val name: String = "John Doe",
83+ * @get:JvmName("ignored") val age: Int = 25,
84+ * @ColumnName("a") val test: Double = 1.0,
85+ * @get:ColumnName("b") val test2: Double = 2.0,
86+ * )
87+ * ```
88+ * to
89+ * ```kt
90+ * @Sparkify
91+ * data class User(
92+ * @get:JvmName("name") val name: String = "John Doe",
93+ * @get:JvmName("age") val age: Int = 25,
94+ * @get:JvmName("a") @ColumnName("a") val test: Double = 1.0,
95+ * @get:JvmName("b") @get:ColumnName("b") val test2: Double = 2.0,
96+ * )
97+ * ```
98+ */
5499 override fun visitProperty (declaration : IrProperty ) {
55100 val origin = declaration.parent as ? IrClass ? : return super .visitProperty(declaration)
56101 if (sparkifyAnnotationFqNames.none { origin.hasAnnotation(FqName (it)) })
57102 return super .visitProperty(declaration)
58103
104+ if (! origin.isData) return super .visitProperty(declaration)
105+
59106 // must be in primary constructor
60107 val constructorParams = declaration.parentAsClass.primaryConstructor?.valueParameters
61108 ? : return super .visitProperty(declaration)
@@ -96,7 +143,7 @@ class DataClassPropertyAnnotationGenerator(
96143 .filterNot { it.isAnnotationWithEqualFqName(jvmNameFqName) }
97144
98145 // create a new JvmName annotation with newName
99- val jvmNameClassId = ClassId ( jvmNameFqName.parent(), jvmNameFqName.shortName() )
146+ val jvmNameClassId = jvmNameFqName.toClassId( )
100147 val jvmName = pluginContext.referenceClass(jvmNameClassId)!!
101148 val jvmNameConstructor = jvmName
102149 .constructors
@@ -118,4 +165,210 @@ class DataClassPropertyAnnotationGenerator(
118165 getter.annotations + = jvmNameAnnotationCall
119166 println (" Added @get:JvmName(\" $newName \" ) annotation to property ${origin.name} .${declaration.name} " )
120167 }
168+
169+ private fun FqName.toClassId (): ClassId = ClassId (packageFqName = parent(), topLevelName = shortName())
170+
171+ /* *
172+ * Converts
173+ * ```kt
174+ * @Sparkify
175+ * data class User(
176+ * val name: String = "John Doe",
177+ * val age: Int = 25,
178+ * @ColumnName("a") val test: Double = 1.0,
179+ * @get:ColumnName("b") val test2: Double = 2.0,
180+ * )
181+ * ```
182+ * to
183+ * ```kt
184+ * @Sparkify
185+ * data class User(
186+ * val name: String = "John Doe",
187+ * val age: Int = 25,
188+ * @ColumnName("a") val test: Double = 1.0,
189+ * @get:ColumnName("b") val test2: Double = 2.0,
190+ * ): scala.Product {
191+ * override fun canEqual(that: Any?): Boolean = that is User
192+ * override fun productElement(n: Int): Any = when (n) {
193+ * 0 -> name
194+ * 1 -> age
195+ * 2 -> test
196+ * else -> throw IndexOutOfBoundsException(n.toString())
197+ * }
198+ * override fun productArity(): Int = 4
199+ * }
200+ * ```
201+ */
202+ @OptIn(UnsafeDuringIrConstructionAPI ::class )
203+ override fun visitClass (declaration : IrClass ) {
204+ if (sparkifyAnnotationFqNames.none { declaration.hasAnnotation(FqName (it)) })
205+ return super .visitClass(declaration)
206+
207+ if (! declaration.isData) return super .visitClass(declaration)
208+
209+ // add superclass
210+ val scalaProductClass = productFqNames.firstNotNullOfOrNull {
211+ val classId = ClassId .topLevel(FqName (it))
212+ // ClassId(
213+ // packageFqName = FqName("scala"),
214+ // topLevelName = Name.identifier("Product"),
215+ // )
216+ pluginContext.referenceClass(classId)
217+ }!!
218+
219+ declaration.superTypes + = scalaProductClass.defaultType
220+
221+ // finding the constructor params
222+ val constructorParams = declaration.primaryConstructor?.valueParameters
223+ ? : return super .visitClass(declaration)
224+
225+ // finding properties
226+ val props = declaration.properties
227+
228+ // getting the properties that are in the constructor
229+ val properties = constructorParams.mapNotNull { param ->
230+ props.firstOrNull { it.name == param.name }
231+ }
232+
233+ // finding supertype Equals
234+ val superEqualsInterface = scalaProductClass.superTypes()
235+ .first { it.classFqName?.shortName()?.asString()?.contains(" Equals" ) == true }
236+ .classOrNull ? : return super .visitClass(declaration)
237+
238+ // add canEqual
239+ val superCanEqualFunction = superEqualsInterface.functions.first {
240+ it.owner.name.asString() == " canEqual" &&
241+ it.owner.valueParameters.size == 1 &&
242+ it.owner.valueParameters.first().type == pluginContext.irBuiltIns.anyNType
243+ }
244+
245+ val canEqualFunction = declaration.addFunction(
246+ name = " canEqual" ,
247+ returnType = pluginContext.irBuiltIns.booleanType,
248+ modality = Modality .OPEN ,
249+ )
250+ with (canEqualFunction) {
251+ overriddenSymbols = listOf (superCanEqualFunction)
252+ parent = declaration
253+
254+ // add implicit $this parameter
255+ addDispatchReceiver {
256+ name = SpecialNames .THIS
257+ type = declaration.defaultType
258+ }
259+
260+ // add that parameter
261+ val that = addValueParameter(
262+ name = Name .identifier(" that" ),
263+ type = pluginContext.irBuiltIns.anyNType,
264+ )
265+
266+ // add body
267+ body = pluginContext.irBuiltIns.createIrBuilder(symbol).irBlockBody {
268+ val call = irIs(argument = irGet(that), type = declaration.defaultType)
269+ + irReturn(call)
270+ }
271+ }
272+
273+ // add productArity
274+ val superProductArityFunction = scalaProductClass.functions.first {
275+ it.owner.name.asString() == " productArity" &&
276+ it.owner.valueParameters.isEmpty()
277+ }
278+
279+ val productArityFunction = declaration.addFunction(
280+ name = " productArity" ,
281+ returnType = pluginContext.irBuiltIns.intType,
282+ modality = Modality .OPEN ,
283+ )
284+ with (productArityFunction) {
285+ overriddenSymbols = listOf (superProductArityFunction)
286+ parent = declaration
287+
288+ // add implicit $this parameter
289+ addDispatchReceiver {
290+ name = SpecialNames .THIS
291+ type = declaration.defaultType
292+ }
293+
294+ // add body
295+ body = pluginContext.irBuiltIns.createIrBuilder(symbol).irBlockBody {
296+ val const = properties.size.toIrConst(pluginContext.irBuiltIns.intType)
297+ + irReturn(const)
298+ }
299+ }
300+
301+ // add productElement
302+ val superProductElementFunction = scalaProductClass.functions.first {
303+ it.owner.name.asString() == " productElement" &&
304+ it.owner.valueParameters.size == 1 &&
305+ it.owner.valueParameters.first().type == pluginContext.irBuiltIns.intType
306+ }
307+
308+ val productElementFunction = declaration.addFunction(
309+ name = " productElement" ,
310+ returnType = pluginContext.irBuiltIns.anyNType,
311+ modality = Modality .OPEN ,
312+ )
313+ with (productElementFunction) {
314+ overriddenSymbols = listOf (superProductElementFunction)
315+ parent = declaration
316+
317+ // add implicit $this parameter
318+ val `this ` = addDispatchReceiver {
319+ name = SpecialNames .THIS
320+ type = declaration.defaultType
321+ }
322+
323+ // add n parameter
324+ val n = addValueParameter(
325+ name = Name .identifier(" n" ),
326+ type = pluginContext.irBuiltIns.intType,
327+ )
328+
329+ // add body
330+ body = pluginContext.irBuiltIns.createIrBuilder(symbol).irBlockBody {
331+ val whenBranches = buildList {
332+ for ((i, prop) in properties.withIndex()) {
333+ val condition = irEquals(
334+ arg1 = irGet(n),
335+ arg2 = i.toIrConst(pluginContext.irBuiltIns.intType),
336+ )
337+ val call = irCall(prop.getter!! )
338+ with (call) {
339+ origin = IrStatementOrigin .GET_PROPERTY
340+ dispatchReceiver = irGet(`this `)
341+ }
342+
343+ val branch = irBranch(
344+ condition = condition,
345+ result = call
346+ )
347+ add(branch)
348+ }
349+
350+ val ioobClass = pluginContext.referenceClass(
351+ FqName (" java.lang.IndexOutOfBoundsException" ).toClassId()
352+ )!!
353+ val ioobConstructor = ioobClass.constructors.first { it.owner.valueParameters.isEmpty() }
354+ val throwCall = irThrow(
355+ IrConstructorCallImpl .fromSymbolOwner(
356+ ioobClass.defaultType,
357+ ioobConstructor
358+ )
359+ )
360+ val elseBranch = irElseBranch(throwCall)
361+ add(elseBranch)
362+ }
363+ val whenBlock = irWhen(pluginContext.irBuiltIns.anyNType, whenBranches)
364+ with (whenBlock) {
365+ origin = IrStatementOrigin .IF
366+ }
367+ + irReturn(whenBlock)
368+ }
369+ }
370+
371+ // pass down to the properties
372+ declaration.acceptChildrenVoid(this )
373+ }
121374}
0 commit comments