diff --git a/java/dagger/internal/codegen/binding/Nullability.java b/java/dagger/internal/codegen/binding/Nullability.java index e916e16818c..4231a8fb067 100644 --- a/java/dagger/internal/codegen/binding/Nullability.java +++ b/java/dagger/internal/codegen/binding/Nullability.java @@ -55,13 +55,16 @@ public static Nullability of(XElement element) { } private static ImmutableSet getNullableAnnotations(XElement element) { - return getNullableAnnotations(element.getAllAnnotations().stream()); + return getNullableAnnotations(element.getAllAnnotations().stream(), ImmutableSet.of()); } - private static ImmutableSet getNullableAnnotations(Stream annotations) { + private static ImmutableSet getNullableAnnotations( + Stream annotations, + ImmutableSet filterSet) { return annotations .map(XAnnotations::getClassName) .filter(annotation -> annotation.simpleName().contentEquals("Nullable")) + .filter(annotation -> !filterSet.contains(annotation)) .collect(toImmutableSet()); } diff --git a/java/dagger/internal/codegen/writing/FactoryGenerator.java b/java/dagger/internal/codegen/writing/FactoryGenerator.java index 9fba7cd0907..32195ba6a3e 100644 --- a/java/dagger/internal/codegen/writing/FactoryGenerator.java +++ b/java/dagger/internal/codegen/writing/FactoryGenerator.java @@ -248,7 +248,6 @@ private MethodSpec getMethod(ProvisionBinding binding) { MethodSpec.Builder getMethod = methodBuilder("get") .addModifiers(PUBLIC) - .returns(providedTypeName) .addParameters(assistedParameters.values()); if (factoryTypeName(binding).isPresent()) { @@ -270,10 +269,12 @@ private MethodSpec getMethod(ProvisionBinding binding) { .nullability() .nullableAnnotations() .forEach(getMethod::addAnnotation); + getMethod.returns(providedTypeName); getMethod.addStatement("return $L", invokeNewInstance); } else if (!binding.injectionSites().isEmpty()) { CodeBlock instance = CodeBlock.of("instance"); getMethod + .returns(providedTypeName) .addStatement("$T $L = $L", providedTypeName, instance, invokeNewInstance) .addCode( InjectionSiteMethod.invokeAll( @@ -283,8 +284,11 @@ private MethodSpec getMethod(ProvisionBinding binding) { binding.key().type().xprocessing(), sourceFiles.frameworkFieldUsages(binding.dependencies(), frameworkFields)::get)) .addStatement("return $L", instance); + } else { - getMethod.addStatement("return $L", invokeNewInstance); + getMethod + .returns(providedTypeName) + .addStatement("return $L", invokeNewInstance); } return getMethod.build(); } diff --git a/java/dagger/internal/codegen/writing/InjectionMethods.java b/java/dagger/internal/codegen/writing/InjectionMethods.java index f4752f2c665..39c03dd7ff7 100644 --- a/java/dagger/internal/codegen/writing/InjectionMethods.java +++ b/java/dagger/internal/codegen/writing/InjectionMethods.java @@ -405,7 +405,8 @@ private static MethodSpec methodProxy( if (isVoid(method.getReturnType())) { return builder.addStatement("$L", invocation).build(); } else { - Nullability.of(method) + Nullability nullability = Nullability.of(method); + nullability .nullableAnnotations() .forEach(builder::addAnnotation); return builder diff --git a/javatests/dagger/functional/nullables/JspecifyNullableTest.java b/javatests/dagger/functional/nullables/JspecifyNullableTest.java index 63b4a9ae4e8..b5d42f3b705 100644 --- a/javatests/dagger/functional/nullables/JspecifyNullableTest.java +++ b/javatests/dagger/functional/nullables/JspecifyNullableTest.java @@ -22,6 +22,7 @@ import dagger.Component; import dagger.Module; import dagger.Provides; +import javax.inject.Provider; import org.jspecify.annotations.Nullable; import org.junit.Test; import org.junit.runner.RunWith; @@ -29,40 +30,92 @@ @RunWith(JUnit4.class) public final class JspecifyNullableTest { - @Component(modules = MyModule.class) + @Component(modules = MyModule.class, dependencies = ComponentDependency.class) interface MyComponent { Integer getInt(); + InnerType getInnerType(); + Provider getDependencyProvider(); } + interface Dependency {} + + interface InnerType {} + @Module static class MyModule { - private final Integer value; + private final Integer integer; + private final InnerType innerType; - MyModule(Integer value) { - this.value = value; + MyModule(Integer integer, InnerType innerType) { + this.integer = integer; + this.innerType = innerType; } @Provides @Nullable Integer provideInt() { - return value; + return integer; + } + + @Provides + @Nullable InnerType provideInnerType() { + return innerType; + } + } + + @Component(modules = DependencyModule.class) + interface ComponentDependency { + @Nullable Dependency dependency(); + } + + @Module + static class DependencyModule { + private final Dependency dependency; + + DependencyModule(Dependency dependency) { + this.dependency = dependency; + } + + @Provides + @Nullable Dependency provideDependency() { + return dependency; } } @Test public void testWithValue() { - MyComponent component = - DaggerJspecifyNullableTest_MyComponent.builder().myModule(new MyModule(15)).build(); + MyComponent component = DaggerJspecifyNullableTest_MyComponent.builder() + .myModule(new MyModule(15, new InnerType() {})) + .componentDependency( + DaggerJspecifyNullableTest_ComponentDependency.builder() + .dependencyModule(new DependencyModule(new Dependency() {})).build()) + .build(); assertThat(component.getInt()).isEqualTo(15); + assertThat(component.getInnerType()).isNotNull(); + assertThat(component.getDependencyProvider().get()).isNotNull(); } @Test public void testWithNull() { - MyComponent component = - DaggerJspecifyNullableTest_MyComponent.builder().myModule(new MyModule(null)).build(); + MyComponent component = DaggerJspecifyNullableTest_MyComponent.builder() + .myModule(new MyModule(null, null)) + .componentDependency( + DaggerJspecifyNullableTest_ComponentDependency.builder() + .dependencyModule(new DependencyModule(null)).build()) + .build(); NullPointerException expectedException = assertThrows(NullPointerException.class, component::getInt); assertThat(expectedException) .hasMessageThat() .contains("Cannot return null from a non-@Nullable @Provides method"); + NullPointerException expectedException2 = + assertThrows(NullPointerException.class, component::getInnerType); + assertThat(expectedException2) + .hasMessageThat() + .contains("Cannot return null from a non-@Nullable @Provides method"); + NullPointerException expectedException3 = + assertThrows(NullPointerException.class, () -> component.getDependencyProvider().get()); + assertThat(expectedException3) + .hasMessageThat() + .contains("Cannot return null from a non-@Nullable @Provides method"); } }