Skip to content

Commit

Permalink
ContributeSubComponent: Support returning Super Type (#83)
Browse files Browse the repository at this point in the history
Co-authored-by: Rick Busarow <rickbusarow@gmail.com>
  • Loading branch information
esafirm and RBusarow authored Jan 1, 2025
1 parent aded86a commit 0f07e65
Show file tree
Hide file tree
Showing 16 changed files with 285 additions and 61 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
**Unreleased**
--------------

### Deprecated
- `ClassReference.functions` has been deprecated in favor of `ClassReference.memberFunctions` and `ClassReference.declaredMemberFunctions`

0.4.0
-----

Expand Down
4 changes: 4 additions & 0 deletions compiler-utils/api/compiler-utils.api
Original file line number Diff line number Diff line change
Expand Up @@ -300,9 +300,11 @@ public abstract class com/squareup/anvil/compiler/internal/reference/ClassRefere
public abstract fun getClassId ()Lorg/jetbrains/kotlin/name/ClassId;
public abstract fun getConstructors ()Ljava/util/List;
public abstract fun getContainingFileAsJavaFile ()Ljava/io/File;
public abstract fun getDeclaredMemberFunctions ()Ljava/util/List;
public abstract fun getFqName ()Lorg/jetbrains/kotlin/name/FqName;
public abstract fun getFunctions ()Ljava/util/List;
protected abstract fun getInnerClassesAndObjects ()Ljava/util/List;
public final fun getMemberFunctions ()Ljava/util/List;
public abstract fun getModule ()Lcom/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor;
public final fun getPackageFqName ()Lorg/jetbrains/kotlin/name/FqName;
public abstract fun getProperties ()Ljava/util/List;
Expand Down Expand Up @@ -330,6 +332,7 @@ public final class com/squareup/anvil/compiler/internal/reference/ClassReference
public final fun getClazz ()Lorg/jetbrains/kotlin/descriptors/ClassDescriptor;
public fun getConstructors ()Ljava/util/List;
public fun getContainingFileAsJavaFile ()Ljava/io/File;
public fun getDeclaredMemberFunctions ()Ljava/util/List;
public fun getFqName ()Lorg/jetbrains/kotlin/name/FqName;
public fun getFunctions ()Ljava/util/List;
public fun getModule ()Lcom/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor;
Expand All @@ -355,6 +358,7 @@ public final class com/squareup/anvil/compiler/internal/reference/ClassReference
public final fun getClazz ()Lorg/jetbrains/kotlin/psi/KtClassOrObject;
public fun getConstructors ()Ljava/util/List;
public fun getContainingFileAsJavaFile ()Ljava/io/File;
public fun getDeclaredMemberFunctions ()Ljava/util/List;
public fun getFqName ()Lorg/jetbrains/kotlin/name/FqName;
public fun getFunctions ()Ljava/util/List;
public fun getModule ()Lcom/squareup/anvil/compiler/internal/reference/AnvilModuleDescriptor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,30 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere
public val packageFqName: FqName get() = classId.packageFqName

public abstract val constructors: List<MemberFunctionReference>

@Deprecated(
"renamed to `declaredMemberFunctions`. " +
"Use `memberFunctions` to include inherited functions.",
replaceWith = ReplaceWith("declaredMemberFunctions"),
)
public abstract val functions: List<MemberFunctionReference>

/**
* All functions that are declared in this class, including overrides.
* This list does not include inherited functions that are not overridden by this class.
*/
public abstract val declaredMemberFunctions: List<MemberFunctionReference>

/**
* All functions declared in this class or any of its super-types.
*/
public val memberFunctions: List<MemberFunctionReference> by lazy(NONE) {
declaredMemberFunctions + directSuperTypeReferences()
.flatMap { it.asClassReference().memberFunctions }
}

public abstract val properties: List<MemberPropertyReference>

public abstract val typeParameters: List<TypeParameterReference>

protected abstract val innerClassesAndObjects: List<ClassReference>
Expand Down Expand Up @@ -146,7 +168,13 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere
clazz.containingFileAsJavaFile()
}

override val functions: List<MemberFunctionReference.Psi> by lazy(NONE) {
@Deprecated(
"renamed to `declaredMemberFunctions`. Use `memberFunctions` to include inherited functions.",
replaceWith = ReplaceWith("declaredMemberFunctions"),
)
override val functions: List<MemberFunctionReference.Psi> get() = declaredMemberFunctions

override val declaredMemberFunctions: List<MemberFunctionReference.Psi> by lazy(NONE) {
clazz
.children
.filterIsInstance<KtClassBody>()
Expand Down Expand Up @@ -263,9 +291,14 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere
)
}

override val functions: List<MemberFunctionReference.Descriptor> by lazy(NONE) {
@Deprecated(
"renamed to `declaredMemberFunctions`. Use `memberFunctions` to include inherited functions.",
replaceWith = ReplaceWith("declaredMemberFunctions"),
)
override val functions: List<MemberFunctionReference.Descriptor> get() = declaredMemberFunctions
override val declaredMemberFunctions: List<MemberFunctionReference.Descriptor> by lazy(NONE) {
clazz.unsubstitutedMemberScope
.getContributedDescriptors(kindFilter = DescriptorKindFilter.FUNCTIONS)
.getDescriptorsFiltered(kindFilter = DescriptorKindFilter.FUNCTIONS)
.filterIsInstance<FunctionDescriptor>()
.filterNot { it is ConstructorDescriptor }
.map { it.toFunctionReference(this) }
Expand All @@ -279,10 +312,8 @@ public sealed class ClassReference : Comparable<ClassReference>, AnnotatedRefere
clazz.unsubstitutedMemberScope
.getDescriptorsFiltered(kindFilter = DescriptorKindFilter.VARIABLES)
.filterIsInstance<PropertyDescriptor>()
.filter {
// Remove inherited properties that aren't overridden in this class.
it.kind == DECLARATION
}
// Remove inherited properties that aren't overridden in this class.
.filter { it.kind == DECLARATION }
.map { it.toPropertyReference(this) }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,13 @@ internal object ContributesSubcomponentCodeGen : AnvilApplicabilityChecker {
.filter { it.isAbstract }
.toList()

if (functions.size != 1 || functions[0].returnType?.resolve()
?.resolveKSClassDeclaration() != this
) {
val returnType = functions.singleOrNull()?.returnType?.resolve()?.resolveKSClassDeclaration()
if (returnType != this) {

val isReturnSuperType = returnType != null && this.superTypes
.any { type -> type.resolve().resolveKSClassDeclaration() == returnType }
if (isReturnSuperType) return

throw KspAnvilException(
node = factory,
message = "A factory must have exactly one abstract function returning the " +
Expand Down Expand Up @@ -325,7 +329,7 @@ internal object ContributesSubcomponentCodeGen : AnvilApplicabilityChecker {
)
}

val functions = componentInterface.functions
val functions = componentInterface.memberFunctions
.filter { it.returnType().asClassReference() == this }

if (functions.size >= 2) {
Expand Down Expand Up @@ -378,7 +382,7 @@ internal object ContributesSubcomponentCodeGen : AnvilApplicabilityChecker {
)
}

val functions = factory.functions
val functions = factory.memberFunctions
.let { functions ->
if (factory.isInterface()) {
functions
Expand All @@ -387,7 +391,13 @@ internal object ContributesSubcomponentCodeGen : AnvilApplicabilityChecker {
}
}

if (functions.size != 1 || functions[0].returnType().asClassReference() != this) {
val returnType = functions.singleOrNull()?.returnType()?.asClassReference()
if (returnType != this) {

val isReturnSuperType = returnType != null && this.directSuperTypeReferences()
.any { it.asClassReference() == returnType }
if (isReturnSuperType) return

throw AnvilCompilationExceptionClassReference(
classReference = factory,
message = "A factory must have exactly one abstract function returning the " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ internal class ContributesSubcomponentHandlerGenerator(
)
}

val functions = componentInterface.functions
val functions = componentInterface.memberFunctions
.filter { it.isAbstract() && it.visibility() == PUBLIC }
.filter {
val returnType = it.returnType().asClassReference()
Expand Down Expand Up @@ -333,9 +333,14 @@ internal class ContributesSubcomponentHandlerGenerator(
)
}

val createComponentFunctions = factory.functions
val createComponentFunctions = factory.memberFunctions
.filter { it.isAbstract() }
.filter { it.returnType().asClassReference().fqName == contributionFqName }
.filter {
val returnType = it.returnType().asClassReference()
returnType.fqName == contributionFqName ||
contribution.clazz.directSuperTypeReferences()
.any { type -> type.asClassReference() == returnType }
}

if (createComponentFunctions.size != 1) {
throw AnvilCompilationExceptionClassReference(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,14 @@ internal class KspContributesSubcomponentHandlerSymbolProcessor(
} else {
function.asMemberOf(implementingType).returnTypeOrNull()
}
returnTypeToCheck
?.resolveKSClassDeclaration()
?.toClassName() == contributionClassName

if (returnTypeToCheck != null) {
val returnTypeClassName = returnTypeToCheck.resolveKSClassDeclaration()?.toClassName()
returnTypeClassName == contributionClassName ||
returnTypeToCheck.isAssignableFrom(contribution.clazz.asType(emptyList()))
} else {
false
}
}
.toList()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ internal object AssistedFactoryCodeGen : AnvilApplicabilityChecker {
val assistedFunctions = allSuperTypeClassReferences(includeSelf = true)
.distinctBy { it.fqName }
.flatMap { clazz ->
clazz.functions
clazz.declaredMemberFunctions
.filter {
it.isAbstract() &&
(it.visibility() == Visibility.PUBLIC || it.visibility() == Visibility.PROTECTED)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ internal object BindsMethodValidator : AnvilApplicabilityChecker {
.forEach { clazz ->
(clazz.companionObjects() + clazz)
.asSequence()
.flatMap { it.functions }
.flatMap { it.declaredMemberFunctions }
.filter { it.isAnnotatedWith(daggerBindsFqName) }
.also { functions ->
assertNoDuplicateFunctions(clazz, functions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ internal object ProvidesMethodFactoryCodeGen : AnvilApplicabilityChecker {
.asSequence()

val functions = types
.flatMap { it.functions }
.flatMap { it.declaredMemberFunctions }
.filter { it.isAnnotatedWith(daggerProvidesFqName) }
.onEach { function ->
checkFunctionIsNotAbstract(clazz, function)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class ContributesSubcomponentGeneratorTest(
}
}

@Test fun `there is a hint for contributed subcomponents with an interace factory`() {
@Test fun `there is a hint for contributed subcomponents with an interface factory`() {
compile(
"""
package com.squareup.test
Expand Down Expand Up @@ -493,6 +493,84 @@ class ContributesSubcomponentGeneratorTest(
}
}

@Test fun `a factory function may be defined in a super interface`() {
compile(
"""
package com.squareup.test
import com.squareup.anvil.annotations.ContributesSubcomponent
import com.squareup.anvil.annotations.ContributesSubcomponent.Factory
import com.squareup.anvil.annotations.ContributesTo
import com.squareup.anvil.annotations.MergeComponent
@ContributesSubcomponent(Any::class, parentScope = Unit::class)
interface SubcomponentInterface {
@ContributesTo(Unit::class)
interface AnyParentComponent {
fun createFactory(): ComponentFactory
}
interface Creator {
fun createComponent(): SubcomponentInterface
}
@Factory
interface ComponentFactory : Creator
}
@MergeComponent(Unit::class)
interface ComponentInterface
""",
mode = mode,
) {
assertThat(subcomponentInterface.hintSubcomponent?.java).isEqualTo(subcomponentInterface)
assertThat(subcomponentInterface.hintSubcomponentParentScope).isEqualTo(Unit::class)

assertThat(subcomponentInterface.componentFactoryInterface.methods.map { it.name })
.containsExactly("createComponent")
}
}

@Test fun `a factory function may returns the component super type`() {
compile(
"""
package com.squareup.test
import com.squareup.anvil.annotations.ContributesSubcomponent
import com.squareup.anvil.annotations.ContributesSubcomponent.Factory
import com.squareup.anvil.annotations.ContributesTo
import com.squareup.anvil.annotations.MergeComponent
interface BaseSubcomponentInterface {
interface Factory {
fun createComponent(): BaseSubcomponentInterface
}
}
@ContributesSubcomponent(Any::class, parentScope = Unit::class)
interface SubcomponentInterface : BaseSubcomponentInterface {
@Factory
interface ComponentFactory: BaseSubcomponentInterface.Factory
@ContributesTo(Unit::class)
interface ParentComponent {
fun createFactory(): ComponentFactory
}
}
@MergeComponent(Unit::class)
interface ComponentInterface
""",
mode = mode,
) {
assertThat(subcomponentInterface.hintSubcomponent?.java).isEqualTo(subcomponentInterface)
assertThat(subcomponentInterface.hintSubcomponentParentScope).isEqualTo(Unit::class)

assertThat(subcomponentInterface.componentFactoryInterface.methods.map { it.name })
.containsExactly("createComponent")
}
}

@Test
fun `using Dagger's @Subcomponent_Factory is an error`() {
compile(
Expand Down Expand Up @@ -616,6 +694,9 @@ class ContributesSubcomponentGeneratorTest(
private val Class<*>.parentComponentInterface: Class<*>
get() = classLoader.loadClass("$canonicalName\$AnyParentComponent")

private val Class<*>.componentFactoryInterface: Class<*>
get() = classLoader.loadClass("$canonicalName\$ComponentFactory")

private val JvmCompilationResult.subcomponentInterface1: Class<*>
get() = classLoader.loadClass("com.squareup.test.SubcomponentInterface1")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,13 @@ class ClassReferenceTest {
assertThat(psiRef.isGenericClass()).isFalse()
assertThat(descriptorRef.isGenericClass()).isFalse()

assertThat(psiRef.functions.single().returnType().isGenericType()).isFalse()
assertThat(
descriptorRef.functions.single { it.name == "string" }
psiRef.declaredMemberFunctions.single()
.returnType()
.isGenericType(),
).isFalse()
assertThat(
descriptorRef.declaredMemberFunctions.single { it.name == "string" }
.returnType()
.isGenericType(),
).isFalse()
Expand All @@ -223,17 +227,25 @@ class ClassReferenceTest {
).isTrue()
}
"SomeClass3" -> {
assertThat(psiRef.functions.single().returnType().isGenericType()).isTrue()
assertThat(
descriptorRef.functions.single { it.name == "list" }
psiRef.declaredMemberFunctions.single()
.returnType()
.isGenericType(),
).isTrue()
assertThat(
descriptorRef.declaredMemberFunctions.single { it.name == "list" }
.returnType()
.isGenericType(),
).isTrue()
}
"SomeClass4" -> {
assertThat(psiRef.functions.single().returnType().isGenericType()).isTrue()
assertThat(
descriptorRef.functions.single { it.name == "list" }
psiRef.declaredMemberFunctions.single()
.returnType()
.isGenericType(),
).isTrue()
assertThat(
descriptorRef.declaredMemberFunctions.single { it.name == "list" }
.returnType()
.isGenericType(),
).isTrue()
Expand Down
Loading

0 comments on commit 0f07e65

Please sign in to comment.