From 392181d7ee6b9b6d566daaa1a2bfc732d17c2bb2 Mon Sep 17 00:00:00 2001 From: Igor Chevdar Date: Fri, 12 Apr 2024 15:22:32 +0300 Subject: [PATCH] [K/N] Devirtualization: fixed the problem with type checks ... during selection of proper callee #KT-67218 Fixed --- .../optimizations/DevirtualizationAnalysis.kt | 120 ++++++++++++------ 1 file changed, 83 insertions(+), 37 deletions(-) diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DevirtualizationAnalysis.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DevirtualizationAnalysis.kt index 42778f8d115d4..28b2ebf824f58 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DevirtualizationAnalysis.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/DevirtualizationAnalysis.kt @@ -17,7 +17,6 @@ import org.jetbrains.kotlin.backend.konan.util.IntArrayList import org.jetbrains.kotlin.backend.konan.util.LongArrayList import org.jetbrains.kotlin.backend.konan.lower.getObjectClassInstanceFunction import org.jetbrains.kotlin.descriptors.ClassKind -import org.jetbrains.kotlin.descriptors.explicitParameters import org.jetbrains.kotlin.ir.IrElement import org.jetbrains.kotlin.ir.builders.* import org.jetbrains.kotlin.ir.declarations.* @@ -1546,6 +1545,57 @@ internal object DevirtualizationAnalysis { } else -> irBlock(expression) { + /* + * More than one possible callee - need to select the proper one. + * There are two major cases here: + * - there is only one possible receiver type, and all what is needed is just compare the type infos + * - otherwise, there are multiple receiver types (meaning the actual callee has not been overridden in + * the inheritors), and a full type check operation is required. + * These checks cannot be performed in arbitrary order - the check for a derived type must be + * performed before the check for the base type. + * To improve performance, we try to perform these checks in the following order: first, those with only one + * receiver, then classes type checks, and finally interface type checks. + * Note: performing the slowest check last allows to place it to else clause and skip it improving performance. + * The actual order in which perform these checks is found by a simple back tracking algorithm + * (since the number of possible callees is small, it is ok in terms of performance). + */ + + data class Target(val actualCallee: DataFlowIR.FunctionSymbol.Declared, val possibleReceivers: List) { + val declType = actualCallee.irFunction!!.parentAsClass + val weight = when { + possibleReceivers.size == 1 -> 0 // The fastest. + declType.isInterface -> 2 // The slowest. + else -> 1 // In between. + } + var used = false + } + + val targets = possibleCallees.map { Target(it.first, it.second) } + var bestOrder: List? = null + var bestLexOrder = Int.MAX_VALUE + fun backTrack(order: List, lexOrder: Int) { + if (order.size == targets.size) { + if (lexOrder < bestLexOrder) { + bestOrder = order + bestLexOrder = lexOrder + } + return + } + for (target in targets.filterNot { it.used }) { + val fitsAsNext = order.none { target.declType.isSubclassOf(it.declType) } + if (!fitsAsNext) continue + val nextOrder = order + target + // Don't count the last one since it will be in the else clause. + val nextLexOrder = if (nextOrder.size == targets.size) lexOrder else lexOrder * 3 + target.weight + target.used = true + backTrack(nextOrder, nextLexOrder) + target.used = false + } + } + + backTrack(emptyList(), 0) + require(bestLexOrder != Int.MAX_VALUE) // Should never happen since there are no cycles in a type hierarchy. + val arguments = expression.getArgumentsWithIr().mapIndexed { index, arg -> irSplitCoercion(caller, arg.second, "arg$index", arg.first.type) } @@ -1555,45 +1605,41 @@ internal object DevirtualizationAnalysis { putValueArgument(0, irGet(receiver)) }) } - val branches = mutableListOf() - possibleCallees - // Try to leave the most complicated case for the last, - // and, hopefully, place it in the else clause. - .sortedBy { it.second.size } - .mapIndexedTo(branches) { index, devirtualizedCallee -> - val (actualCallee, receiverTypes) = devirtualizedCallee - val condition = - if (optimize && index == possibleCallees.size - 1) - irTrue() // Don't check last type in optimize mode. - else { - if (receiverTypes.size == 1) { - // It is faster to just compare type infos instead of a full type check. - val receiverType = receiverTypes[0] - val expectedTypeInfo = IrClassReferenceImpl( - startOffset, endOffset, - symbols.nativePtrType, - receiverType.irClass!!.symbol, - receiverType.irClass.defaultType - ) - irCall(nativePtrEqualityOperatorSymbol).apply { - putValueArgument(0, irGet(typeInfo)) - putValueArgument(1, expectedTypeInfo) - } - } else { - val receiverType = actualCallee.irFunction!!.parentAsClass - irCall(isSubtype, listOf(receiverType.defaultType)).apply { - putValueArgument(0, irGet(typeInfo)) - } - } - } - IrBranchImpl( - startOffset = startOffset, - endOffset = endOffset, - condition = condition, - result = irDevirtualizedCall(expression, type, actualCallee, arguments) + bestOrder!!.mapIndexedTo(branches) { index, target -> + val (actualCallee, receiverTypes) = target + val condition = when { + optimize && index == possibleCallees.size - 1 -> { + // Don't check the last type in optimize mode. + irTrue() + } + receiverTypes.size == 1 -> { + // It is faster to just compare type infos instead of a full type check. + val receiverType = receiverTypes[0] + val expectedTypeInfo = IrClassReferenceImpl( + startOffset, endOffset, + symbols.nativePtrType, + receiverType.irClass!!.symbol, + receiverType.irClass.defaultType ) + irCall(nativePtrEqualityOperatorSymbol).apply { + putValueArgument(0, irGet(typeInfo)) + putValueArgument(1, expectedTypeInfo) + } + } + else -> { + irCall(isSubtype, listOf(target.declType.defaultType)).apply { + putValueArgument(0, irGet(typeInfo)) + } } + } + IrBranchImpl( + startOffset = startOffset, + endOffset = endOffset, + condition = condition, + result = irDevirtualizedCall(expression, type, actualCallee, arguments) + ) + } if (!optimize) { // Add else branch throwing exception for debug purposes. branches.add(IrBranchImpl( startOffset = startOffset,