Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for top-level Kotlin functions #847 #1147

Merged
merged 6 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions utbot-core/src/main/kotlin/org/utbot/common/KClassUtil.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,9 @@ package org.utbot.common

import java.lang.reflect.InvocationTargetException
import java.lang.reflect.Method
import kotlin.reflect.KClass

val Class<*>.nameOfPackage: String get() = `package`?.name?:""

fun Method.invokeCatching(obj: Any?, args: List<Any?>) = try {
Result.success(invoke(obj, *args.toTypedArray()))
} catch (e: InvocationTargetException) {
Result.failure<Nothing>(e.targetException)
}

val KClass<*>.allNestedClasses: List<KClass<*>>
get() = listOf(this) + nestedClasses.flatMap { it.allNestedClasses }
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,10 @@ val Class<*>.isFinal
get() = Modifier.isFinal(modifiers)

val Class<*>.isProtected
get() = Modifier.isProtected(modifiers)
get() = Modifier.isProtected(modifiers)

val Class<*>.nameOfPackage: String
get() = `package`?.name?:""

val Class<*>.allNestedClasses: List<Class<*>>
get() = listOf(this) + this.declaredClasses.flatMap { it.allNestedClasses }
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@ import kotlin.reflect.KCallable
import kotlin.reflect.KClass
import kotlin.reflect.KFunction
import kotlin.reflect.KProperty
import kotlin.reflect.full.extensionReceiverParameter
import kotlin.reflect.full.instanceParameter
import kotlin.reflect.jvm.internal.impl.load.kotlin.header.KotlinClassHeader
import kotlin.reflect.jvm.javaConstructor
import kotlin.reflect.jvm.javaField
import kotlin.reflect.jvm.javaGetter
import kotlin.reflect.jvm.javaMethod
import kotlin.reflect.jvm.kotlinFunction

// ClassId utils

Expand Down Expand Up @@ -178,6 +181,14 @@ val ClassId.isDoubleType: Boolean
val ClassId.isClassType: Boolean
get() = this == classClassId

/**
* Checks if the class is a Kotlin class with kind File (see [Metadata.kind] for more details)
*/
val ClassId.isKotlinFile: Boolean
get() = jClass.annotations.filterIsInstance<Metadata>().singleOrNull()?.let {
KotlinClassHeader.Kind.getById(it.kind) == KotlinClassHeader.Kind.FILE_FACADE
} ?: false

val voidClassId = ClassId("void")
val booleanClassId = ClassId("boolean")
val byteClassId = ClassId("byte")
Expand Down Expand Up @@ -430,6 +441,12 @@ val MethodId.method: Method
?: error("Can't find method $signature in ${declaringClass.name}")
}

/**
* See [KCallable.extensionReceiverParameter] for more details
*/
val MethodId.extensionReceiverParameterIndex: Int?
volivan239 marked this conversation as resolved.
Show resolved Hide resolved
get() = this.method.kotlinFunction?.extensionReceiverParameter?.index

// TODO: maybe cache it somehow in the future
val ConstructorId.constructor: Constructor<*>
get() {
Expand Down Expand Up @@ -484,6 +501,7 @@ val Method.displayName: String

val KCallable<*>.declaringClazz: Class<*>
get() = when (this) {
is KFunction<*> -> javaMethod?.declaringClass?.kotlin
EgorkaKulikov marked this conversation as resolved.
Show resolved Hide resolved
is CallableReference -> owner as? KClass<*>
else -> instanceParameter?.type?.classifier as? KClass<*>
}?.java ?: tryConstructor(this) ?: error("Can't get parent class for $this")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.utbot.examples.codegen

import org.junit.jupiter.api.Test
import org.utbot.testcheckers.eq
import org.utbot.tests.infrastructure.UtValueTestCaseChecker
import kotlin.reflect.KFunction3

@Suppress("UNCHECKED_CAST")
internal class FileWithTopLevelFunctionsTest : UtValueTestCaseChecker(testClass = FileWithTopLevelFunctionsReflectHelper.clazz.kotlin) {
@Test
fun topLevelSumTest() {
check(
::topLevelSum,
eq(1),
)
}

@Test
fun extensionOnBasicTypeTest() {
check(
Int::extensionOnBasicType,
eq(1),
)
}

@Test
fun extensionOnCustomClassTest() {
check(
// NB: cast is important here because we need to treat receiver as an argument to be able to check its content in matchers
CustomClass::extensionOnCustomClass as KFunction3<*, CustomClass, CustomClass, Boolean>,
eq(2),
{ receiver, argument, result -> receiver === argument && result == true },
{ receiver, argument, result -> receiver !== argument && result == false },
additionalDependencies = dependenciesForClassExtensions
)
}

companion object {
// Compilation of extension methods for ref objects produces call to
// `kotlin.jvm.internal.Intrinsics::checkNotNullParameter`, so we need to add it to dependencies
val dependenciesForClassExtensions = arrayOf<Class<*>>(kotlin.jvm.internal.Intrinsics::class.java)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ import org.utbot.framework.plugin.api.MethodId
import org.utbot.framework.plugin.api.UtExplicitlyThrownException
import org.utbot.framework.plugin.api.util.isStatic
import org.utbot.framework.plugin.api.util.exceptions
import org.utbot.framework.plugin.api.util.extensionReceiverParameterIndex
import org.utbot.framework.plugin.api.util.humanReadableName
import org.utbot.framework.plugin.api.util.id
import org.utbot.framework.plugin.api.util.isArray
import org.utbot.framework.plugin.api.util.isSubtypeOf
Expand Down Expand Up @@ -110,7 +112,7 @@ internal class CgCallableAccessManagerImpl(val context: CgContext) : CgCallableA
override operator fun CgIncompleteMethodCall.invoke(vararg args: Any?): CgMethodCall {
val resolvedArgs = args.resolve()
val methodCall = if (method.canBeCalledWith(caller, resolvedArgs)) {
CgMethodCall(caller, method, resolvedArgs.guardedForDirectCallOf(method))
CgMethodCall(caller, method, resolvedArgs.guardedForDirectCallOf(method)).takeCallerFromArgumentsIfNeeded()
EgorkaKulikov marked this conversation as resolved.
Show resolved Hide resolved
} else {
method.callWithReflection(caller, resolvedArgs)
}
Expand Down Expand Up @@ -194,6 +196,29 @@ internal class CgCallableAccessManagerImpl(val context: CgContext) : CgCallableA
else -> false
}

/**
* For Kotlin extension functions, real caller is one of the arguments in JVM method (and declaration class is omitted),
* thus we should move it from arguments to caller
*
* For example, if we have `Int.f(a: Int)` declared in `Main.kt`, the JVM method signature will be `MainKt.f(Int, Int)`
* and in Kotlin we should render this not like `MainKt.f(a, b)` but like `a.f(b)`
*/
private fun CgMethodCall.takeCallerFromArgumentsIfNeeded(): CgMethodCall {
if (codegenLanguage == CodegenLanguage.KOTLIN) {
// TODO: reflection calls for util and some of mockito methods produce exceptions => here we suppose that
// methods for BuiltinClasses are not extensions by default (which should be true as long as we suppose them to be java methods)
if (executableId.classId !is BuiltinClassId) {
executableId.extensionReceiverParameterIndex?.let { receiverIndex ->
require(caller == null) { "${executableId.humanReadableName} is an extension function but it already has a non-static caller provided" }
val args = arguments.toMutableList()
return CgMethodCall(args.removeAt(receiverIndex), executableId, args, typeParameters)
}
}
}

return this
}

private infix fun CgExpression.canBeArgOf(type: ClassId): Boolean {
// TODO: SAT-1210 support generics so that we wouldn't need to check specific cases such as this one
if (this is CgExecutableCall && (executableId == any || executableId == anyOfClass)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,13 @@ internal abstract class CgAbstractRenderer(
}
}

/**
* Returns true if one can call methods of this class without specifying a caller (for example if ClassId represents this instance)
*/
protected abstract val ClassId.methodsAreAccessibleAsTopLevel: Boolean

private val MethodId.accessibleByName: Boolean
get() = (context.shouldOptimizeImports && this in context.importedStaticMethods) || classId == context.generatedClass
get() = (context.shouldOptimizeImports && this in context.importedStaticMethods) || classId.methodsAreAccessibleAsTopLevel

override fun visit(element: CgElement) {
val error =
Expand Down Expand Up @@ -654,8 +659,10 @@ internal abstract class CgAbstractRenderer(
}

override fun visit(element: CgStaticFieldAccess) {
print(element.declaringClass.asString())
print(".")
if (!element.declaringClass.methodsAreAccessibleAsTopLevel) {
print(element.declaringClass.asString())
print(".")
}
print(element.fieldName)
}

Expand Down Expand Up @@ -707,7 +714,10 @@ internal abstract class CgAbstractRenderer(
if (caller != null) {
// 'this' can be omitted, otherwise render caller
if (caller !is CgThisInstance) {
// TODO: we need parentheses for calls like (-1).inv(), do something smarter here
if (caller !is CgVariable) print("(")
EgorkaKulikov marked this conversation as resolved.
Show resolved Hide resolved
caller.accept(this)
if (caller !is CgVariable) print(")")
renderAccess(caller)
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ internal class CgJavaRenderer(context: CgRendererContext, printer: CgPrinter = C

override val langPackage: String = "java.lang"

override val ClassId.methodsAreAccessibleAsTopLevel: Boolean
get() = this == context.generatedClass

override fun visit(element: AbstractCgClass<*>) {
for (annotation in element.annotations) {
annotation.accept(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ import org.utbot.framework.plugin.api.TypeParameters
import org.utbot.framework.plugin.api.WildcardTypeParameter
import org.utbot.framework.plugin.api.util.id
import org.utbot.framework.plugin.api.util.isArray
import org.utbot.framework.plugin.api.util.isKotlinFile
import org.utbot.framework.plugin.api.util.isPrimitive
import org.utbot.framework.plugin.api.util.isPrimitiveWrapper
import org.utbot.framework.plugin.api.util.kClass
Expand All @@ -72,6 +73,10 @@ internal class CgKotlinRenderer(context: CgRendererContext, printer: CgPrinter =

override val langPackage: String = "kotlin"

override val ClassId.methodsAreAccessibleAsTopLevel: Boolean
// NB: the order of operands is important as `isKotlinFile` uses reflection and thus can't be called on context.generatedClass
get() = (this == context.generatedClass) || isKotlinFile

override fun visit(element: AbstractCgClass<*>) {
for (annotation in element.annotations) {
annotation.accept(this)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ import kotlin.reflect.KFunction
import kotlin.reflect.KParameter
import kotlin.reflect.jvm.javaType

// Note that rules for obtaining signature here should correlate with PsiMethod.signature()
fun KFunction<*>.signature() =
Signature(this.name, this.parameters.filter { it.kind == KParameter.Kind.VALUE }.map { it.type.javaType.typeName })
Signature(this.name, this.parameters.filter { it.kind != KParameter.Kind.INSTANCE }.map { it.type.javaType.typeName })

data class Signature(val name: String, val parameterTypes: List<String?>) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import org.utbot.framework.plugin.api.util.UtContext
import org.utbot.framework.plugin.api.util.executableId
import org.utbot.framework.plugin.api.util.id
import org.utbot.framework.plugin.api.util.jClass
import org.utbot.framework.plugin.api.util.method
import org.utbot.framework.plugin.services.JdkInfo
import org.utbot.framework.process.generated.*
import org.utbot.framework.util.ConflictTriggers
Expand All @@ -36,7 +37,7 @@ import org.utbot.summary.summarize
import java.io.File
import java.net.URLClassLoader
import java.nio.file.Paths
import kotlin.reflect.full.functions
import kotlin.reflect.jvm.kotlinFunction
import kotlin.time.Duration.Companion.seconds

private val messageFromMainTimeoutMillis = 120.seconds
Expand Down Expand Up @@ -158,8 +159,8 @@ private fun EngineProcessModel.setup(
synchronizer.measureExecutionForTermination(findMethodsInClassMatchingSelected) { params ->
val classId = kryoHelper.readObject<ClassId>(params.classId)
val selectedSignatures = params.signatures.map { Signature(it.name, it.parametersTypes) }
FindMethodsInClassMatchingSelectedResult(kryoHelper.writeObject(classId.jClass.kotlin.allNestedClasses.flatMap { clazz ->
clazz.functions.sortedWith(compareBy { selectedSignatures.indexOf(it.signature()) })
FindMethodsInClassMatchingSelectedResult(kryoHelper.writeObject(classId.jClass.allNestedClasses.flatMap { clazz ->
clazz.id.allMethods.mapNotNull { it.method.kotlinFunction }.sortedWith(compareBy { selectedSignatures.indexOf(it.signature()) })
.filter { it.signature().normalized() in selectedSignatures }
.map { it.executableId }
}))
Expand All @@ -168,7 +169,7 @@ private fun EngineProcessModel.setup(
val classId = kryoHelper.readObject<ClassId>(params.classId)
val bySignature = kryoHelper.readObject<Map<Signature, List<String>>>(params.bySignature)
FindMethodParamNamesResult(kryoHelper.writeObject(
classId.jClass.kotlin.allNestedClasses.flatMap { it.functions }
classId.jClass.allNestedClasses.flatMap { clazz -> clazz.id.allMethods.mapNotNull { it.method.kotlinFunction } }
.mapNotNull { method -> bySignature[method.signature()]?.let { params -> method.executableId to params } }
.toMap()
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,16 @@ class EngineProcess(parent: Lifetime, val project: Project) {
}

private fun MemberInfo.paramNames(): List<String> =
(this.member as PsiMethod).parameterList.parameters.map { it.name }
(this.member as PsiMethod).parameterList.parameters.map {
if (it.name.startsWith("\$this"))
// If member is Kotlin extension function, name of first argument isn't good for further usage,
// so we better choose name based on type of receiver.
//
// There seems no API to check whether parameter is an extension receiver by PSI
it.type.presentableText
else
it.name
}

fun generate(
mockInstalled: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import com.intellij.openapi.vfs.VirtualFile
import com.intellij.psi.*
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.refactoring.util.classMembers.MemberInfo
import org.jetbrains.kotlin.asJava.findFacadeClass
import org.jetbrains.kotlin.idea.core.getPackage
import org.jetbrains.kotlin.idea.core.util.toPsiDirectory
import org.jetbrains.kotlin.idea.core.util.toPsiFile
Expand All @@ -26,6 +27,7 @@ import org.utbot.intellij.plugin.util.extractFirstLevelMembers
import org.utbot.intellij.plugin.util.isVisible
import java.util.*
import org.jetbrains.kotlin.j2k.getContainingClass
import org.jetbrains.kotlin.psi.KtFile
import org.jetbrains.kotlin.utils.addIfNotNull
import org.utbot.framework.plugin.api.util.LockFile
import org.utbot.intellij.plugin.models.packageName
Expand Down Expand Up @@ -218,7 +220,7 @@ class GenerateTestsAction : AnAction(), UpdateInBackground {
}

private fun getAllClasses(directory: PsiDirectory): Set<PsiClass> {
val allClasses = directory.files.flatMap { getClassesFromFile(it) }.toMutableSet()
val allClasses = directory.files.flatMap { PsiElementHandler.makePsiElementHandler(it).getClassesFromFile(it) }.toMutableSet()
for (subDir in directory.subdirectories) allClasses += getAllClasses(subDir)
return allClasses
}
Expand All @@ -231,15 +233,10 @@ class GenerateTestsAction : AnAction(), UpdateInBackground {
if (!dirsArePackages) {
return emptySet()
}
val allClasses = psiFiles.flatMap { getClassesFromFile(it) }.toMutableSet()
val allClasses = psiFiles.flatMap { PsiElementHandler.makePsiElementHandler(it).getClassesFromFile(it) }.toMutableSet()
allClasses.addAll(psiFiles.mapNotNull { (it as? KtFile)?.findFacadeClass() })
EgorkaKulikov marked this conversation as resolved.
Show resolved Hide resolved
for (psiDir in psiDirectories) allClasses += getAllClasses(psiDir)

return allClasses
}

private fun getClassesFromFile(psiFile: PsiFile): List<PsiClass> {
val psiElementHandler = PsiElementHandler.makePsiElementHandler(psiFile)
return PsiTreeUtil.getChildrenOfTypeAsList(psiFile, psiElementHandler.classClass)
.map { psiElementHandler.toPsi(it, PsiClass::class.java) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package org.utbot.intellij.plugin.ui.utils

import com.intellij.psi.PsiClass
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiFile
import com.intellij.psi.util.findParentOfType
import org.jetbrains.kotlin.asJava.findFacadeClass
import org.jetbrains.kotlin.idea.testIntegration.KotlinCreateTestIntention
import org.jetbrains.kotlin.psi.KtClass
import org.jetbrains.kotlin.psi.KtClassOrObject
Expand All @@ -24,13 +27,27 @@ class KotlinPsiElementHandler(
return element.toUElement()?.javaPsi as? T ?: error("Could not cast $element to $clazz")
}

override fun isCreateTestActionAvailable(element: PsiElement): Boolean =
getTarget(element)?.let { KotlinCreateTestIntention().applicabilityRange(it) != null } ?: false
override fun getClassesFromFile(psiFile: PsiFile): List<PsiClass> {
return listOfNotNull((psiFile as? KtFile)?.findFacadeClass()) + super.getClassesFromFile(psiFile)
}

override fun isCreateTestActionAvailable(element: PsiElement): Boolean {
getTarget(element)?.let {
return KotlinCreateTestIntention().applicabilityRange(it) != null
}
return (element.containingFile as? KtFile)?.findFacadeClass() != null
}

private fun getTarget(element: PsiElement?): KtNamedDeclaration? =
element?.parentsWithSelf
?.firstOrNull { it is KtClassOrObject || it is KtNamedDeclaration && it.parent is KtFile } as? KtNamedDeclaration

override fun containingClass(element: PsiElement): PsiClass? =
element.parentsWithSelf.firstOrNull { it is KtClassOrObject }?.let { toPsi(it, PsiClass::class.java) }
override fun containingClass(element: PsiElement): PsiClass? {
element.findParentOfType<KtClassOrObject>(strict=false)?.let {
return toPsi(it, PsiClass::class.java)
}
return element.findParentOfType<KtFile>(strict=false)?.findFacadeClass()?.let {
toPsi(it, PsiClass::class.java)
}
}
}
Loading