1
1
package org.jetbrains.kotlinx.spark.api.compilerPlugin.ir
2
2
3
3
import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
4
+ import org.jetbrains.kotlin.backend.common.ir.addDispatchReceiver
5
+ import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
6
+ import org.jetbrains.kotlin.backend.common.lower.irThrow
7
+ import org.jetbrains.kotlin.descriptors.Modality
4
8
import org.jetbrains.kotlin.ir.IrElement
5
9
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
6
10
import org.jetbrains.kotlin.ir.backend.js.utils.valueArguments
11
+ import org.jetbrains.kotlin.ir.builders.declarations.addFunction
12
+ import org.jetbrains.kotlin.ir.builders.declarations.addValueParameter
13
+ import org.jetbrains.kotlin.ir.builders.irBlockBody
14
+ import org.jetbrains.kotlin.ir.builders.irBranch
15
+ import org.jetbrains.kotlin.ir.builders.irCall
16
+ import org.jetbrains.kotlin.ir.builders.irElseBranch
17
+ import org.jetbrains.kotlin.ir.builders.irEquals
18
+ import org.jetbrains.kotlin.ir.builders.irGet
19
+ import org.jetbrains.kotlin.ir.builders.irIs
20
+ import org.jetbrains.kotlin.ir.builders.irReturn
21
+ import org.jetbrains.kotlin.ir.builders.irWhen
7
22
import org.jetbrains.kotlin.ir.declarations.IrClass
8
- import org.jetbrains.kotlin.ir.declarations.IrDeclaration
9
- import org.jetbrains.kotlin.ir.declarations.IrFile
10
- import org.jetbrains.kotlin.ir.declarations.IrModuleFragment
11
23
import org.jetbrains.kotlin.ir.declarations.IrProperty
12
- import org.jetbrains.kotlin.ir.expressions.IrBlockBody
13
24
import org.jetbrains.kotlin.ir.expressions.IrConst
25
+ import org.jetbrains.kotlin.ir.expressions.IrStatementOrigin
14
26
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
15
27
import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl
28
+ import org.jetbrains.kotlin.ir.symbols.UnsafeDuringIrConstructionAPI
29
+ import org.jetbrains.kotlin.ir.types.classFqName
30
+ import org.jetbrains.kotlin.ir.types.classOrNull
16
31
import org.jetbrains.kotlin.ir.types.defaultType
32
+ import org.jetbrains.kotlin.ir.types.superTypes
17
33
import org.jetbrains.kotlin.ir.util.constructors
34
+ import org.jetbrains.kotlin.ir.util.defaultType
35
+ import org.jetbrains.kotlin.ir.util.functions
18
36
import org.jetbrains.kotlin.ir.util.hasAnnotation
19
37
import org.jetbrains.kotlin.ir.util.isAnnotationWithEqualFqName
20
38
import org.jetbrains.kotlin.ir.util.parentAsClass
21
39
import org.jetbrains.kotlin.ir.util.primaryConstructor
40
+ import org.jetbrains.kotlin.ir.util.properties
41
+ import org.jetbrains.kotlin.ir.util.toIrConst
22
42
import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
23
43
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
24
44
import org.jetbrains.kotlin.name.ClassId
25
45
import org.jetbrains.kotlin.name.FqName
46
+ import org.jetbrains.kotlin.name.Name
47
+ import org.jetbrains.kotlin.name.SpecialNames
26
48
27
49
class DataClassPropertyAnnotationGenerator (
28
50
private val pluginContext : IrPluginContext ,
29
51
private val sparkifyAnnotationFqNames : List <String >,
30
52
private val columnNameAnnotationFqNames : List <String >,
53
+ private val productFqNames : List <String >,
31
54
) : IrElementVisitorVoid {
32
55
33
56
init {
@@ -51,11 +74,35 @@ class DataClassPropertyAnnotationGenerator(
51
74
}
52
75
}
53
76
77
+ /* *
78
+ * Converts
79
+ * ```kt
80
+ * @Sparkify
81
+ * data class User(
82
+ * val name: String = "John Doe",
83
+ * @get:JvmName("ignored") val age: Int = 25,
84
+ * @ColumnName("a") val test: Double = 1.0,
85
+ * @get:ColumnName("b") val test2: Double = 2.0,
86
+ * )
87
+ * ```
88
+ * to
89
+ * ```kt
90
+ * @Sparkify
91
+ * data class User(
92
+ * @get:JvmName("name") val name: String = "John Doe",
93
+ * @get:JvmName("age") val age: Int = 25,
94
+ * @get:JvmName("a") @ColumnName("a") val test: Double = 1.0,
95
+ * @get:JvmName("b") @get:ColumnName("b") val test2: Double = 2.0,
96
+ * )
97
+ * ```
98
+ */
54
99
override fun visitProperty (declaration : IrProperty ) {
55
100
val origin = declaration.parent as ? IrClass ? : return super .visitProperty(declaration)
56
101
if (sparkifyAnnotationFqNames.none { origin.hasAnnotation(FqName (it)) })
57
102
return super .visitProperty(declaration)
58
103
104
+ if (! origin.isData) return super .visitProperty(declaration)
105
+
59
106
// must be in primary constructor
60
107
val constructorParams = declaration.parentAsClass.primaryConstructor?.valueParameters
61
108
? : return super .visitProperty(declaration)
@@ -96,7 +143,7 @@ class DataClassPropertyAnnotationGenerator(
96
143
.filterNot { it.isAnnotationWithEqualFqName(jvmNameFqName) }
97
144
98
145
// create a new JvmName annotation with newName
99
- val jvmNameClassId = ClassId ( jvmNameFqName.parent(), jvmNameFqName.shortName() )
146
+ val jvmNameClassId = jvmNameFqName.toClassId( )
100
147
val jvmName = pluginContext.referenceClass(jvmNameClassId)!!
101
148
val jvmNameConstructor = jvmName
102
149
.constructors
@@ -118,4 +165,210 @@ class DataClassPropertyAnnotationGenerator(
118
165
getter.annotations + = jvmNameAnnotationCall
119
166
println (" Added @get:JvmName(\" $newName \" ) annotation to property ${origin.name} .${declaration.name} " )
120
167
}
168
+
169
+ private fun FqName.toClassId (): ClassId = ClassId (packageFqName = parent(), topLevelName = shortName())
170
+
171
+ /* *
172
+ * Converts
173
+ * ```kt
174
+ * @Sparkify
175
+ * data class User(
176
+ * val name: String = "John Doe",
177
+ * val age: Int = 25,
178
+ * @ColumnName("a") val test: Double = 1.0,
179
+ * @get:ColumnName("b") val test2: Double = 2.0,
180
+ * )
181
+ * ```
182
+ * to
183
+ * ```kt
184
+ * @Sparkify
185
+ * data class User(
186
+ * val name: String = "John Doe",
187
+ * val age: Int = 25,
188
+ * @ColumnName("a") val test: Double = 1.0,
189
+ * @get:ColumnName("b") val test2: Double = 2.0,
190
+ * ): scala.Product {
191
+ * override fun canEqual(that: Any?): Boolean = that is User
192
+ * override fun productElement(n: Int): Any = when (n) {
193
+ * 0 -> name
194
+ * 1 -> age
195
+ * 2 -> test
196
+ * else -> throw IndexOutOfBoundsException(n.toString())
197
+ * }
198
+ * override fun productArity(): Int = 4
199
+ * }
200
+ * ```
201
+ */
202
+ @OptIn(UnsafeDuringIrConstructionAPI ::class )
203
+ override fun visitClass (declaration : IrClass ) {
204
+ if (sparkifyAnnotationFqNames.none { declaration.hasAnnotation(FqName (it)) })
205
+ return super .visitClass(declaration)
206
+
207
+ if (! declaration.isData) return super .visitClass(declaration)
208
+
209
+ // add superclass
210
+ val scalaProductClass = productFqNames.firstNotNullOfOrNull {
211
+ val classId = ClassId .topLevel(FqName (it))
212
+ // ClassId(
213
+ // packageFqName = FqName("scala"),
214
+ // topLevelName = Name.identifier("Product"),
215
+ // )
216
+ pluginContext.referenceClass(classId)
217
+ }!!
218
+
219
+ declaration.superTypes + = scalaProductClass.defaultType
220
+
221
+ // finding the constructor params
222
+ val constructorParams = declaration.primaryConstructor?.valueParameters
223
+ ? : return super .visitClass(declaration)
224
+
225
+ // finding properties
226
+ val props = declaration.properties
227
+
228
+ // getting the properties that are in the constructor
229
+ val properties = constructorParams.mapNotNull { param ->
230
+ props.firstOrNull { it.name == param.name }
231
+ }
232
+
233
+ // finding supertype Equals
234
+ val superEqualsInterface = scalaProductClass.superTypes()
235
+ .first { it.classFqName?.shortName()?.asString()?.contains(" Equals" ) == true }
236
+ .classOrNull ? : return super .visitClass(declaration)
237
+
238
+ // add canEqual
239
+ val superCanEqualFunction = superEqualsInterface.functions.first {
240
+ it.owner.name.asString() == " canEqual" &&
241
+ it.owner.valueParameters.size == 1 &&
242
+ it.owner.valueParameters.first().type == pluginContext.irBuiltIns.anyNType
243
+ }
244
+
245
+ val canEqualFunction = declaration.addFunction(
246
+ name = " canEqual" ,
247
+ returnType = pluginContext.irBuiltIns.booleanType,
248
+ modality = Modality .OPEN ,
249
+ )
250
+ with (canEqualFunction) {
251
+ overriddenSymbols = listOf (superCanEqualFunction)
252
+ parent = declaration
253
+
254
+ // add implicit $this parameter
255
+ addDispatchReceiver {
256
+ name = SpecialNames .THIS
257
+ type = declaration.defaultType
258
+ }
259
+
260
+ // add that parameter
261
+ val that = addValueParameter(
262
+ name = Name .identifier(" that" ),
263
+ type = pluginContext.irBuiltIns.anyNType,
264
+ )
265
+
266
+ // add body
267
+ body = pluginContext.irBuiltIns.createIrBuilder(symbol).irBlockBody {
268
+ val call = irIs(argument = irGet(that), type = declaration.defaultType)
269
+ + irReturn(call)
270
+ }
271
+ }
272
+
273
+ // add productArity
274
+ val superProductArityFunction = scalaProductClass.functions.first {
275
+ it.owner.name.asString() == " productArity" &&
276
+ it.owner.valueParameters.isEmpty()
277
+ }
278
+
279
+ val productArityFunction = declaration.addFunction(
280
+ name = " productArity" ,
281
+ returnType = pluginContext.irBuiltIns.intType,
282
+ modality = Modality .OPEN ,
283
+ )
284
+ with (productArityFunction) {
285
+ overriddenSymbols = listOf (superProductArityFunction)
286
+ parent = declaration
287
+
288
+ // add implicit $this parameter
289
+ addDispatchReceiver {
290
+ name = SpecialNames .THIS
291
+ type = declaration.defaultType
292
+ }
293
+
294
+ // add body
295
+ body = pluginContext.irBuiltIns.createIrBuilder(symbol).irBlockBody {
296
+ val const = properties.size.toIrConst(pluginContext.irBuiltIns.intType)
297
+ + irReturn(const)
298
+ }
299
+ }
300
+
301
+ // add productElement
302
+ val superProductElementFunction = scalaProductClass.functions.first {
303
+ it.owner.name.asString() == " productElement" &&
304
+ it.owner.valueParameters.size == 1 &&
305
+ it.owner.valueParameters.first().type == pluginContext.irBuiltIns.intType
306
+ }
307
+
308
+ val productElementFunction = declaration.addFunction(
309
+ name = " productElement" ,
310
+ returnType = pluginContext.irBuiltIns.anyNType,
311
+ modality = Modality .OPEN ,
312
+ )
313
+ with (productElementFunction) {
314
+ overriddenSymbols = listOf (superProductElementFunction)
315
+ parent = declaration
316
+
317
+ // add implicit $this parameter
318
+ val `this ` = addDispatchReceiver {
319
+ name = SpecialNames .THIS
320
+ type = declaration.defaultType
321
+ }
322
+
323
+ // add n parameter
324
+ val n = addValueParameter(
325
+ name = Name .identifier(" n" ),
326
+ type = pluginContext.irBuiltIns.intType,
327
+ )
328
+
329
+ // add body
330
+ body = pluginContext.irBuiltIns.createIrBuilder(symbol).irBlockBody {
331
+ val whenBranches = buildList {
332
+ for ((i, prop) in properties.withIndex()) {
333
+ val condition = irEquals(
334
+ arg1 = irGet(n),
335
+ arg2 = i.toIrConst(pluginContext.irBuiltIns.intType),
336
+ )
337
+ val call = irCall(prop.getter!! )
338
+ with (call) {
339
+ origin = IrStatementOrigin .GET_PROPERTY
340
+ dispatchReceiver = irGet(`this `)
341
+ }
342
+
343
+ val branch = irBranch(
344
+ condition = condition,
345
+ result = call
346
+ )
347
+ add(branch)
348
+ }
349
+
350
+ val ioobClass = pluginContext.referenceClass(
351
+ FqName (" java.lang.IndexOutOfBoundsException" ).toClassId()
352
+ )!!
353
+ val ioobConstructor = ioobClass.constructors.first { it.owner.valueParameters.isEmpty() }
354
+ val throwCall = irThrow(
355
+ IrConstructorCallImpl .fromSymbolOwner(
356
+ ioobClass.defaultType,
357
+ ioobConstructor
358
+ )
359
+ )
360
+ val elseBranch = irElseBranch(throwCall)
361
+ add(elseBranch)
362
+ }
363
+ val whenBlock = irWhen(pluginContext.irBuiltIns.anyNType, whenBranches)
364
+ with (whenBlock) {
365
+ origin = IrStatementOrigin .IF
366
+ }
367
+ + irReturn(whenBlock)
368
+ }
369
+ }
370
+
371
+ // pass down to the properties
372
+ declaration.acceptChildrenVoid(this )
373
+ }
121
374
}
0 commit comments