Skip to content

Commit

Permalink
KSP: Mark dependencies are aggregated if the visitor is of aggregated…
Browse files Browse the repository at this point in the history
… type (#10487)
  • Loading branch information
dstepanov authored Feb 13, 2024
1 parent faddb47 commit bdd7e68
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import java.io.File
import java.io.OutputStream
import java.util.*

internal class KotlinOutputVisitor(private val environment: SymbolProcessorEnvironment): AbstractClassWriterOutputVisitor(false) {
internal class KotlinOutputVisitor(private val environment: SymbolProcessorEnvironment, private val context: KotlinVisitorContext): AbstractClassWriterOutputVisitor(false) {

override fun visitClass(classname: String, vararg originatingElements: Element): OutputStream {
return environment.codeGenerator.createNewFile(
Expand Down Expand Up @@ -86,6 +86,6 @@ internal class KotlinOutputVisitor(private val environment: SymbolProcessorEnvir
} else {
emptyArray()
}
return Dependencies(aggregating = originatingElements.size > 1, sources = sources)
return Dependencies(aggregating = context.aggregating || originatingElements.size > 1, sources = sources)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ internal class BeanDefinitionProcessor(private val environment: SymbolProcessorE

override fun finish() {
try {
val outputVisitor = KotlinOutputVisitor(environment)
val outputVisitor = KotlinOutputVisitor(environment, visitorContext!!)
val processed = HashSet<String>()
var count = 0
for (beanDefinitionCreator in beanDefinitionMap.values) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,27 +48,20 @@ import java.util.function.BiConsumer
import kotlin.collections.ArrayList

@OptIn(KspExperimental::class)
internal open class KotlinVisitorContext(
internal class KotlinVisitorContext(
private val environment: SymbolProcessorEnvironment,
var resolver: Resolver
) : VisitorContext {

private val visitorAttributes: MutableConvertibleValues<Any>
private val elementFactory: KotlinElementFactory
private val outputVisitor = KotlinOutputVisitor(environment)
val annotationMetadataBuilder: KotlinAnnotationMetadataBuilder
private val elementAnnotationMetadataFactory: KotlinElementAnnotationMetadataFactory
private val expressionCompilationContextFactory : ExpressionCompilationContextFactory
private val visitorAttributes: MutableConvertibleValues<Any> = MutableConvertibleValuesMap()
private val elementFactory: KotlinElementFactory = KotlinElementFactory(this)
private val outputVisitor = KotlinOutputVisitor(environment, this)
val annotationMetadataBuilder = KotlinAnnotationMetadataBuilder(environment, resolver, this)
private val elementAnnotationMetadataFactory = KotlinElementAnnotationMetadataFactory(false, annotationMetadataBuilder)
private val expressionCompilationContextFactory = DefaultExpressionCompilationContextFactory(this)
val nativeElementsHelper = KotlinNativeElementsHelper(resolver)

var aggregating: Boolean = false
init {
visitorAttributes = MutableConvertibleValuesMap()
annotationMetadataBuilder = KotlinAnnotationMetadataBuilder(environment, resolver, this)
elementFactory = KotlinElementFactory(this)
elementAnnotationMetadataFactory =
KotlinElementAnnotationMetadataFactory(false, annotationMetadataBuilder)
expressionCompilationContextFactory = DefaultExpressionCompilationContextFactory(this)

try {
// Workaround for bug in KSP https://github.com/google/ksp/issues/1493
val resolverImplClass = ClassUtils.forName("com.google.devtools.ksp.processing.impl.ResolverImpl", javaClass.classLoader).orElseThrow()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ internal open class TypeElementSymbolProcessor(private val environment: SymbolPr
// before RepositoryMapper is going to process repositories and read entities

for (loadedVisitor in loadedVisitors) {
visitorContext.aggregating = loadedVisitor.visitor.visitorKind == TypeElementVisitor.VisitorKind.AGGREGATING
for (typeElement in elements) {
if (!loadedVisitor.matches(typeElement)) {
continue
Expand Down Expand Up @@ -155,6 +156,7 @@ internal open class TypeElementSymbolProcessor(private val environment: SymbolPr

override fun finish() {
for (loadedVisitor in loadedVisitors) {
visitorContext.aggregating = loadedVisitor.visitor.visitorKind == TypeElementVisitor.VisitorKind.AGGREGATING
try {
loadedVisitor.visitor.finish(visitorContext)
} catch (e: ProcessingException) {
Expand Down

0 comments on commit bdd7e68

Please sign in to comment.