Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't reuse symbols between rounds #393

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import com.google.devtools.ksp.processing.SymbolProcessorProvider
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSName
import me.tatarka.inject.compiler.COMPONENT
import me.tatarka.inject.compiler.InjectGenerator
import me.tatarka.inject.compiler.KMP_COMPONENT_CREATE
Expand All @@ -27,63 +28,89 @@ class InjectProcessor(
private lateinit var provider: KSAstProvider
private lateinit var injectGenerator: InjectGenerator
private lateinit var kmpComponentCreateGenerator: KmpComponentCreateGenerator
private var deferredClasses: List<KSClassDeclaration> = mutableListOf()
private var deferredFunctions: List<KSFunctionDeclaration> = mutableListOf()
private var lastResolver: Resolver? = null
private var deferredClassNames: List<KSName> = mutableListOf()
private var deferredFunctionNames: List<KSName> = mutableListOf()

private val kmpComponentCreateFunctionsByComponentType = mutableMapOf<AstClass, MutableList<AstFunction>>()

override fun process(resolver: Resolver): List<KSAnnotated> {
lastResolver = resolver
provider = KSAstProvider(resolver, logger)
injectGenerator = InjectGenerator(provider, options)
kmpComponentCreateGenerator = KmpComponentCreateGenerator(provider, options)

val previousDeferredClasses = deferredClasses
val previousDeferredFunctions = deferredFunctions

val componentSymbols = previousDeferredClasses + resolver.getSymbolsWithClassAnnotation(
evant marked this conversation as resolved.
Show resolved Hide resolved
packageName = COMPONENT.packageName,
simpleName = COMPONENT.simpleName
)
deferredClasses = componentSymbols.filterNot { element ->
val componentSymbols =
resolver.getSymbolsWithAnnotation(COMPONENT.canonicalName).filterIsInstance<KSClassDeclaration>()
val deferredClasses = componentSymbols.filterNot { element ->
processInject(element, provider, codeGenerator, injectGenerator)
}.toList()
deferredClassNames = deferredClasses.mapNotNull {
val name = it.qualifiedName
if (name == null) {
logger.warn("Unable to defer symbol: ${it.simpleName.asString()}, no qualified name", it)
}
name
}

val kmpComponentCreateSymbols = previousDeferredFunctions + resolver.getSymbolsWithFunctionAnnotation(
packageName = KMP_COMPONENT_CREATE.packageName,
simpleName = KMP_COMPONENT_CREATE.simpleName
)
deferredFunctions = kmpComponentCreateSymbols.filterNot { element ->
val kmpComponentCreateSymbols = resolver.getSymbolsWithAnnotation(KMP_COMPONENT_CREATE.canonicalName)
.filterIsInstance<KSFunctionDeclaration>()
val deferredFunctions = kmpComponentCreateSymbols.filterNot { element ->
processKmpComponentCreate(element, provider, kmpComponentCreateFunctionsByComponentType)
}.toList()
deferredFunctionNames = deferredFunctions.mapNotNull {
val name = it.qualifiedName
if (name == null) {
logger.warn("Unable to defer symbol: ${it.simpleName.asString()}, no qualified name", it)
}
name
}

return deferredClasses + deferredFunctions
}

override fun finish() {
// Last round, generate as much as we can, reporting errors for types that still can't be resolved.
for (element in deferredClasses) {
processInject(
element,
provider,
try {
// Last round, generate as much as we can, reporting errors for types that still can't be resolved.
val resolver = lastResolver ?: return
for (name in deferredClassNames) {
val element = resolver.getClassDeclarationByName(name)
if (element == null) {
logger.error("Failed to resolve: ${name.asString()}")
continue
}
processInject(
element,
provider,
codeGenerator,
injectGenerator,
skipValidation = true
)
}

for (name in deferredFunctionNames) {
val element = resolver.getFunctionDeclarationsByName(
name,
includeTopLevel = true
).firstOrNull()
if (element == null) {
logger.error("Failed to resolve: ${name.asString()}")
continue
}
processKmpComponentCreate(element, provider, kmpComponentCreateFunctionsByComponentType)
}

generateKmpComponentCreateFiles(
codeGenerator,
injectGenerator,
skipValidation = true
kmpComponentCreateGenerator,
kmpComponentCreateFunctionsByComponentType
)
kmpComponentCreateFunctionsByComponentType.clear()
} finally {
lastResolver = null
deferredClassNames = emptyList()
deferredFunctionNames = mutableListOf()
}
deferredClasses = mutableListOf()

for (element in deferredFunctions) {
processKmpComponentCreate(element, provider, kmpComponentCreateFunctionsByComponentType)
}

generateKmpComponentCreateFiles(
codeGenerator,
kmpComponentCreateGenerator,
kmpComponentCreateFunctionsByComponentType
)
kmpComponentCreateFunctionsByComponentType.clear()

deferredFunctions = mutableListOf()
evant marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package me.tatarka.inject.compiler.ksp

import com.google.devtools.ksp.processing.Resolver
import com.google.devtools.ksp.symbol.KSAnnotated
import com.google.devtools.ksp.symbol.KSAnnotation
import com.google.devtools.ksp.symbol.KSClassDeclaration
import com.google.devtools.ksp.symbol.KSDeclaration
import com.google.devtools.ksp.symbol.KSFunctionDeclaration
import com.google.devtools.ksp.symbol.KSType
import com.google.devtools.ksp.symbol.KSTypeParameter
import com.google.devtools.ksp.symbol.KSTypeReference
Expand Down Expand Up @@ -113,47 +110,3 @@ fun KSType.isConcrete(): Boolean {
if (arguments.isEmpty()) return true
return arguments.all { it.type?.resolve()?.isConcrete() ?: false }
}

/**
* A 'fast' version of [Resolver.getSymbolsWithAnnotation]. We only care about class annotations so we can skip a lot
* of the tree.
*/
fun Resolver.getSymbolsWithClassAnnotation(packageName: String, simpleName: String): Sequence<KSClassDeclaration> {
suspend fun SequenceScope<KSClassDeclaration>.visit(declarations: Sequence<KSDeclaration>) {
for (declaration in declarations) {
if (declaration is KSClassDeclaration) {
if (declaration.hasAnnotation(packageName, simpleName)) {
yield(declaration)
}
visit(declaration.declarations)
}
}
}
return sequence {
for (file in getNewFiles()) {
visit(file.declarations)
}
}
}

/**
* A 'fast' version of [Resolver.getSymbolsWithAnnotation]. We only care about function annotations so we can skip a lot
* of the tree.
*/
fun Resolver.getSymbolsWithFunctionAnnotation(
packageName: String,
simpleName: String
): Sequence<KSFunctionDeclaration> {
suspend fun SequenceScope<KSFunctionDeclaration>.visit(declarations: Sequence<KSDeclaration>) {
for (declaration in declarations) {
if (declaration is KSFunctionDeclaration && declaration.hasAnnotation(packageName, simpleName)) {
yield(declaration)
}
}
}
return sequence {
for (file in getNewFiles()) {
visit(file.declarations)
}
}
}